Passed
Push — master ( af9ccf...cbd9f5 )
by
unknown
02:26
created

SupportVectorMachine::train()   B

Complexity

Conditions 2
Paths 2

Size

Total Lines 25
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 25
rs 8.8571
c 0
b 0
f 0
cc 2
eloc 15
nc 2
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Exception\InvalidOperationException;
9
use Phpml\Exception\LibsvmCommandException;
10
use Phpml\Helper\Trainable;
11
12
class SupportVectorMachine
13
{
14
    use Trainable;
15
16
    /**
     * SVM type; passed to svm-train as the -s option.
     *
     * @var int
     */
    private $type;

    /**
     * Kernel type; passed to svm-train as the -t option.
     *
     * @var int
     */
    private $kernel;

    /**
     * Passed to svm-train as the -c option.
     *
     * @var float
     */
    private $cost;

    /**
     * Passed to svm-train as the -n option.
     *
     * @var float
     */
    private $nu;

    /**
     * Passed to svm-train as the -d option.
     *
     * @var int
     */
    private $degree;

    /**
     * Passed to svm-train as the -g option; the option is omitted when null.
     *
     * @var float|null
     */
    private $gamma;

    /**
     * Passed to svm-train as the -r option.
     *
     * @var float
     */
    private $coef0;

    /**
     * Passed to svm-train as the -p option.
     *
     * @var float
     */
    private $epsilon;

    /**
     * Passed to svm-train as the -e option.
     *
     * @var float
     */
    private $tolerance;

    /**
     * Passed to svm-train as the -m option.
     *
     * @var int
     */
    private $cacheSize;

    /**
     * Passed to svm-train as the -h option (formatted as an integer).
     *
     * @var bool
     */
    private $shrinking;

    /**
     * Passed to svm-train as the -b option (formatted as an integer);
     * also gates predictProbability().
     *
     * @var bool
     */
    private $probabilityEstimates;

    /**
     * Directory prefix (with trailing separator) of the svm-train and
     * svm-predict executables.
     *
     * @var string
     */
    private $binPath;

    /**
     * Directory prefix (with trailing separator) under which temporary
     * training, test, model and output files are created.
     *
     * @var string
     */
    private $varPath;

    /**
     * Contents of the model file produced by the last train() call.
     *
     * @var string
     */
    private $model;

    /**
     * Targets accumulated over all train() calls.
     *
     * @var array
     */
    private $targets = [];
95
96
    /**
     * Stores the libsvm training options and derives the default binary and
     * temporary-file directories.
     *
     * The defaults are resolved relative to the directory two levels above
     * the directory of this file: "bin/libsvm/" for the executables and
     * "var/" for temporary files.
     */
    public function __construct(
        int $type,
        int $kernel,
        float $cost = 1.0,
        float $nu = 0.5,
        int $degree = 3,
        ?float $gamma = null,
        float $coef0 = 0.0,
        float $epsilon = 0.1,
        float $tolerance = 0.001,
        int $cacheSize = 100,
        bool $shrinking = true,
        bool $probabilityEstimates = false
    ) {
        // Default locations, relative to the project root.
        $projectRoot = realpath(__DIR__.DIRECTORY_SEPARATOR.'..'.DIRECTORY_SEPARATOR.'..').DIRECTORY_SEPARATOR;
        $this->binPath = $projectRoot.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $projectRoot.'var'.DIRECTORY_SEPARATOR;

        // Model selection.
        $this->type = $type;
        $this->kernel = $kernel;

        // Numeric training options.
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;

        // Boolean switches.
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;
    }
128
129
    /**
     * Overrides the directory that holds the libsvm executables.
     *
     * The path is given a trailing directory separator and then validated
     * by verifyBinPath() before it is stored.
     *
     * @throws InvalidArgumentException when verifyBinPath() rejects the path
     */
    public function setBinPath(string $binPath): void
    {
        $normalizedPath = $binPath;
        $this->ensureDirectorySeparator($normalizedPath);

        // Validation runs on the normalized path, so the separator is part of any error message.
        $this->verifyBinPath($normalizedPath);

        $this->binPath = $normalizedPath;
    }
136
137
    /**
     * Overrides the directory used for temporary files.
     *
     * Writability is checked on the path exactly as given; the trailing
     * directory separator is appended only afterwards.
     *
     * @throws InvalidArgumentException when the path is not writable
     */
    public function setVarPath(string $varPath): void
    {
        $isWritable = is_writable($varPath);
        if ($isWritable === false) {
            throw new InvalidArgumentException(sprintf('The specified path "%s" is not writable', $varPath));
        }

        $normalizedPath = $varPath;
        $this->ensureDirectorySeparator($normalizedPath);

        $this->varPath = $normalizedPath;
    }
146
147
    /**
     * Trains a model by running svm-train on all samples seen so far.
     *
     * The given samples and targets are appended to the ones from earlier
     * calls, the combined set is written to a temporary file under varPath,
     * and the resulting model file is read into memory. Both temporary files
     * are removed again, including when the command fails.
     *
     * @throws LibsvmCommandException when svm-train exits with a non-zero
     *                                status or its model file cannot be read
     */
    public function train(array $samples, array $targets): void
    {
        // $this->samples is not declared in this class; presumably it comes from the Trainable trait.
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $isRegression = in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true);
        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, $isRegression);

        $trainingSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';
        file_put_contents($trainingSetFileName, $trainingSet);

        $command = $this->buildTrainCommand($trainingSetFileName, $modelFileName);
        $output = [];
        exec(escapeshellcmd($command).' 2>&1', $output, $return);

        unlink($trainingSetFileName);

        if ($return !== 0) {
            // A failed run may still have left a (partial) model file behind; do not leak it.
            if (file_exists($modelFileName)) {
                unlink($modelFileName);
            }

            throw new LibsvmCommandException(
                sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
            );
        }

        $model = false;
        if (file_exists($modelFileName)) {
            $model = file_get_contents($modelFileName);
            unlink($modelFileName);
        }

        // Never store false in $this->model: getModel() is declared to return a string.
        if ($model === false) {
            throw new LibsvmCommandException(
                sprintf('Failed reading libsvm model file: "%s"', $modelFileName)
            );
        }

        $this->model = $model;
    }
172
173
    /**
     * Returns the model contents stored by the last train() call.
     *
     * NOTE(review): the property has no default value, so calling this before
     * train() presumably trips the string return type — confirm whether
     * callers rely on that.
     */
    public function getModel(): string
    {
        $serializedModel = $this->model;

        return $serializedModel;
    }
177
178
    /**
     * Runs svm-predict on the given samples and returns the predictions.
     *
     * For the C_SVC and NU_SVC types the raw output is mapped through
     * DataTransformer::predictions() together with the training targets;
     * for every other type the trimmed output is split on PHP_EOL.
     *
     * @return array|string the full prediction list, or only its first entry
     *                      when $samples[0] is not itself an array
     *
     * @throws LibsvmCommandException
     */
    public function predict(array $samples)
    {
        $rawOutput = $this->runSvmPredict($samples, false);

        $isClassifier = in_array($this->type, [Type::C_SVC, Type::NU_SVC], true);
        $predictions = $isClassifier
            ? DataTransformer::predictions($rawOutput, $this->targets)
            : explode(PHP_EOL, trim($rawOutput));

        // A flat $samples array is treated as one single sample.
        return is_array($samples[0]) ? $predictions : $predictions[0];
    }
199
200
    /**
     * Runs svm-predict with probability estimates enabled.
     *
     * For the C_SVC and NU_SVC types the raw output is mapped through
     * DataTransformer::probabilities() together with the training targets;
     * for every other type the trimmed output is split on PHP_EOL.
     *
     * @return array|string the full result list, or only its first entry
     *                      when $samples[0] is not itself an array
     *
     * @throws InvalidOperationException when the probabilityEstimates flag is not set
     * @throws LibsvmCommandException
     */
    public function predictProbability(array $samples)
    {
        $supportsProbabilities = (bool) $this->probabilityEstimates;
        if (!$supportsProbabilities) {
            throw new InvalidOperationException('Model does not support probabiliy estimates');
        }

        $rawOutput = $this->runSvmPredict($samples, true);

        $isClassifier = in_array($this->type, [Type::C_SVC, Type::NU_SVC], true);
        $predictions = $isClassifier
            ? DataTransformer::probabilities($rawOutput, $this->targets)
            : explode(PHP_EOL, trim($rawOutput));

        // A flat $samples array is treated as one single sample.
        return is_array($samples[0]) ? $predictions : $predictions[0];
    }
225
226
    /**
     * Writes the samples and the current model to temporary files, runs
     * svm-predict on them and returns the raw contents of its output file.
     *
     * All temporary files are removed before returning or throwing.
     *
     * @throws LibsvmCommandException when svm-predict exits with a non-zero
     *                                status or its output file cannot be read
     */
    private function runSvmPredict(array $samples, bool $probabilityEstimates): string
    {
        $testSet = DataTransformer::testSet($samples);

        $testSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        file_put_contents($testSetFileName, $testSet);
        file_put_contents($modelFileName, $this->model);

        $command = $this->buildPredictCommand(
            $testSetFileName,
            $modelFileName,
            $outputFileName,
            $probabilityEstimates
        );
        $output = [];
        exec(escapeshellcmd($command).' 2>&1', $output, $return);

        unlink($testSetFileName);
        unlink($modelFileName);

        // The output file can be missing when the command failed (for example when the
        // binary could not be started), so only touch it if it is really there. Reading
        // or unlinking a missing file would raise warnings that can mask the real error.
        $predictions = false;
        if (file_exists($outputFileName)) {
            $predictions = file_get_contents($outputFileName);
            unlink($outputFileName);
        }

        if ($return !== 0) {
            throw new LibsvmCommandException(
                sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
            );
        }

        // Returning false here would violate the declared string return type.
        if ($predictions === false) {
            throw new LibsvmCommandException(
                sprintf('Failed reading libsvm output file: "%s"', $outputFileName)
            );
        }

        return $predictions;
    }
256
257
    /**
     * Returns the suffix appended to the libsvm executable names on the
     * current platform, chosen from the first three letters of PHP_OS:
     * ".exe" for "WIN", "-osx" for "DAR", and an empty string otherwise.
     */
    private function getOSExtension(): string
    {
        $suffixByPrefix = [
            'WIN' => '.exe',
            'DAR' => '-osx',
        ];

        $osPrefix = strtoupper(substr(PHP_OS, 0, 3));

        return $suffixByPrefix[$osPrefix] ?? '';
    }
268
269
    /**
     * Builds the svm-train command line from the configured options.
     *
     * The -g option is emitted only when gamma is not null. Both file names
     * are shell-quoted with escapeshellarg().
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        $gammaOption = $this->gamma === null ? '' : ' -g '.$this->gamma;

        // Order must match the placeholders of the format string below.
        $arguments = [
            $this->binPath,
            $this->getOSExtension(),
            $this->type,                 // -s
            $this->kernel,               // -t
            $this->cost,                 // -c
            $this->nu,                   // -n
            $this->degree,               // -d
            $gammaOption,                // optional " -g <gamma>"
            $this->coef0,                // -r
            $this->epsilon,              // -p
            $this->cacheSize,            // -m
            $this->tolerance,            // -e
            $this->shrinking,            // -h (%d formats the bool as an integer)
            $this->probabilityEstimates, // -b (%d formats the bool as an integer)
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName),
        ];

        return vsprintf(
            '%ssvm-train%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            $arguments
        );
    }
291
292
    /**
     * Builds the svm-predict command line.
     *
     * The -b option is 1 when probability estimates are requested and 0
     * otherwise. All three file names are shell-quoted with escapeshellarg().
     */
    private function buildPredictCommand(
        string $testSetFileName,
        string $modelFileName,
        string $outputFileName,
        bool $probabilityEstimates
    ): string {
        $probabilityFlag = (int) $probabilityEstimates;

        $quotedTestSet = escapeshellarg($testSetFileName);
        $quotedModel = escapeshellarg($modelFileName);
        $quotedOutput = escapeshellarg($outputFileName);

        return sprintf(
            '%ssvm-predict%s -b %d %s %s %s',
            $this->binPath,
            $this->getOSExtension(),
            $probabilityFlag,
            $quotedTestSet,
            $quotedModel,
            $quotedOutput
        );
    }
308
309
    /**
     * Appends DIRECTORY_SEPARATOR to the given path, in place, unless the
     * path already ends with it.
     *
     * @param string $path modified by reference
     */
    private function ensureDirectorySeparator(string &$path): void
    {
        $lastCharacter = substr($path, -1);
        if ($lastCharacter === DIRECTORY_SEPARATOR) {
            return;
        }

        $path = $path.DIRECTORY_SEPARATOR;
    }
315
316
    /**
     * Checks that the given directory exists and contains executable
     * svm-predict, svm-scale and svm-train files (with the platform suffix).
     *
     * The path is concatenated directly with the file names, so it is
     * expected to end with a directory separator already.
     *
     * @throws InvalidArgumentException when the directory is missing, or one
     *                                  of the files is missing or not executable
     */
    private function verifyBinPath(string $path): void
    {
        $directoryExists = is_dir($path);
        if ($directoryExists === false) {
            throw new InvalidArgumentException(sprintf('The specified path "%s" does not exist', $path));
        }

        $suffix = $this->getOSExtension();
        $requiredBinaries = ['svm-predict', 'svm-scale', 'svm-train'];

        foreach ($requiredBinaries as $binary) {
            $candidate = $path.$binary.$suffix;

            if (file_exists($candidate) === false) {
                throw new InvalidArgumentException(sprintf('File "%s" not found', $candidate));
            }

            if (is_executable($candidate) === false) {
                throw new InvalidArgumentException(sprintf('File "%s" is not executable', $candidate));
            }
        }
    }
334
}
335