Completed
Pull Request — master (#400)
by
unknown
03:36
created

SupportVectorMachine::setVarPath()   A

Complexity

Conditions 3
Paths 3

Size

Total Lines 10
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
eloc 5
c 1
b 0
f 0
dl 0
loc 10
rs 10
cc 3
nc 3
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Exception\InvalidOperationException;
9
use Phpml\Exception\LibsvmCommandException;
10
use Phpml\Helper\Trainable;
11
12
/**
 * Wrapper around the libsvm command-line tools (svm-train / svm-predict).
 *
 * Training and prediction exchange data with the binaries through temporary
 * files created in the "var" path; the binaries themselves are looked up in
 * the "bin" path. Both paths default to directories below the root path
 * (two levels above this file) and can be overridden with the setters.
 */
class SupportVectorMachine
{
    use Trainable;

    /**
     * Passed to svm-train as "-s".
     *
     * @var int
     */
    private $type;

    /**
     * Passed to svm-train as "-t".
     *
     * @var int
     */
    private $kernel;

    /**
     * Passed to svm-train as "-c".
     *
     * @var float
     */
    private $cost;

    /**
     * Passed to svm-train as "-n".
     *
     * @var float
     */
    private $nu;

    /**
     * Passed to svm-train as "-d".
     *
     * @var int
     */
    private $degree;

    /**
     * Passed to svm-train as "-g"; the flag is omitted entirely when null.
     *
     * @var float|null
     */
    private $gamma;

    /**
     * Passed to svm-train as "-r".
     *
     * @var float
     */
    private $coef0;

    /**
     * Passed to svm-train as "-p".
     *
     * @var float
     */
    private $epsilon;

    /**
     * Passed to svm-train as "-e".
     *
     * @var float
     */
    private $tolerance;

    /**
     * Passed to svm-train as "-m".
     *
     * @var int
     */
    private $cacheSize;

    /**
     * Passed to svm-train as "-h" (1 or 0).
     *
     * @var bool
     */
    private $shrinking;

    /**
     * Passed to svm-train as "-b" (1 or 0); also gates predictProbability().
     *
     * @var bool
     */
    private $probabilityEstimates;

    /**
     * Directory holding the libsvm binaries (with trailing separator), or null for the default.
     *
     * @var ?string
     */
    private $binPath = null;

    /**
     * Directory for temporary files (with trailing separator), or null for the default.
     *
     * @var ?string
     */
    private $varPath = null;

    /**
     * Contents of the model file written by svm-train; not assigned by this class until train() has run.
     *
     * @var string
     */
    private $model;

    /**
     * Targets accumulated over all train() calls.
     *
     * @var array
     */
    private $targets = [];

    /**
     * Stores the libsvm parameters; nothing is validated or executed here.
     *
     * @param int        $type                 value for "-s"
     * @param int        $kernel               value for "-t"
     * @param float      $cost                 value for "-c"
     * @param float      $nu                   value for "-n"
     * @param int        $degree               value for "-d"
     * @param float|null $gamma                value for "-g", or null to omit the flag
     * @param float      $coef0                value for "-r"
     * @param float      $epsilon              value for "-p"
     * @param float      $tolerance            value for "-e"
     * @param int        $cacheSize            value for "-m"
     * @param bool       $shrinking            value for "-h"
     * @param bool       $probabilityEstimates value for "-b"
     */
    public function __construct(
        int $type,
        int $kernel,
        float $cost = 1.0,
        float $nu = 0.5,
        int $degree = 3,
        ?float $gamma = null,
        float $coef0 = 0.0,
        float $epsilon = 0.1,
        float $tolerance = 0.001,
        int $cacheSize = 100,
        bool $shrinking = true,
        bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;
    }

    /**
     * Sets a new bin path for the SVM model.
     * If null is provided, the default bin path will be used.
     *
     * @param string|null $binPath directory containing svm-predict, svm-scale and svm-train
     *
     * @throws InvalidArgumentException when the directory or one of the binaries is missing or not executable
     */
    public function setBinPath(?string $binPath): void
    {
        if ($binPath !== null) {
            $this->ensureDirectorySeparator($binPath);
            $this->verifyBinPath($binPath);
        }

        $this->binPath = $binPath;
    }

    /**
     * Sets a new var path for the SVM model.
     * If null is provided, the default var path will be used.
     *
     * @param string|null $varPath directory in which temporary files are created
     *
     * @throws InvalidArgumentException when the path is not a writable directory
     */
    public function setVarPath(?string $varPath): void
    {
        if ($varPath !== null) {
            // is_writable() alone also accepts a writable regular file, which would
            // only fail later when temporary files are created "inside" it.
            if (!is_dir($varPath) || !is_writable($varPath)) {
                throw new InvalidArgumentException(sprintf('The specified path "%s" is not writable', $varPath));
            }

            $this->ensureDirectorySeparator($varPath);
        }

        $this->varPath = $varPath;
    }

    /**
     * Appends the given data to the stored training data and re-trains the model with svm-train.
     *
     * @throws LibsvmCommandException when svm-train exits with a non-zero status
     */
    public function train(array $samples, array $targets): void
    {
        // $this->samples is not declared in this class; presumably provided by the Trainable trait.
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $isSvrType = in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true);
        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, $isSvrType);

        $trainingSetFileName = $this->getVarPath().uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';
        file_put_contents($trainingSetFileName, $trainingSet);

        $command = $this->buildTrainCommand($trainingSetFileName, $modelFileName);
        $output = [];
        exec(escapeshellcmd($command).' 2>&1', $output, $return);

        unlink($trainingSetFileName);

        if ($return !== 0) {
            // Do not leave a partially written model behind in the var path.
            if (file_exists($modelFileName)) {
                unlink($modelFileName);
            }

            throw new LibsvmCommandException(
                sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
            );
        }

        $this->model = (string) file_get_contents($modelFileName);

        unlink($modelFileName);
    }

    /**
     * Returns the raw model file contents captured by the last train() call.
     */
    public function getModel(): string
    {
        return $this->model;
    }

    /**
     * Runs svm-predict on the given samples.
     *
     * When the first element of $samples is not an array, $samples is treated as a
     * single sample and only the first prediction is returned.
     *
     * @return array|string
     *
     * @throws LibsvmCommandException
     */
    public function predict(array $samples)
    {
        $predictions = $this->runSvmPredict($samples, false);

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::predictions($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        return $this->unwrapSingleSample($samples, $predictions);
    }

    /**
     * Runs svm-predict with probability estimates enabled ("-b 1").
     *
     * When the first element of $samples is not an array, $samples is treated as a
     * single sample and only the first result is returned.
     *
     * @return array|string
     *
     * @throws InvalidOperationException when the instance was created without probability estimates
     * @throws LibsvmCommandException
     */
    public function predictProbability(array $samples)
    {
        if (!$this->probabilityEstimates) {
            throw new InvalidOperationException('Model does not support probabiliy estimates');
        }

        $predictions = $this->runSvmPredict($samples, true);

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::probabilities($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        return $this->unwrapSingleSample($samples, $predictions);
    }

    /**
     * Returns the configured bin path, or "<root>/bin/libsvm/" when none was set.
     */
    protected function getBinPath(): string
    {
        return $this->binPath
            ?? static::getRootPath().'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
    }

    /**
     * Returns the configured var path, or "<root>/var/" when none was set.
     */
    protected function getVarPath(): string
    {
        return $this->varPath
            ?? static::getRootPath().'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Returns the resolved directory two levels above this file, with a trailing separator.
     */
    protected static function getRootPath(): string
    {
        return realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..'])).DIRECTORY_SEPARATOR;
    }

    /**
     * Writes the samples and the current model to temporary files, runs svm-predict
     * and returns the raw contents of its output file. All temporary files are removed.
     *
     * @throws LibsvmCommandException when svm-predict exits with a non-zero status
     */
    private function runSvmPredict(array $samples, bool $probabilityEstimates): string
    {
        $testSet = DataTransformer::testSet($samples);

        $testSetFileName = $this->getVarPath().uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';
        file_put_contents($testSetFileName, $testSet);
        file_put_contents($modelFileName, $this->model);

        $command = $this->buildPredictCommand(
            $testSetFileName,
            $modelFileName,
            $outputFileName,
            $probabilityEstimates
        );
        $output = [];
        exec(escapeshellcmd($command).' 2>&1', $output, $return);

        unlink($testSetFileName);
        unlink($modelFileName);

        // Check the exit status before touching the output file: a failed run may
        // never have created it, and reading it first would raise PHP warnings
        // ahead of (or instead of) the exception below.
        if ($return !== 0) {
            if (file_exists($outputFileName)) {
                unlink($outputFileName);
            }

            throw new LibsvmCommandException(
                sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
            );
        }

        $predictions = (string) file_get_contents($outputFileName);

        unlink($outputFileName);

        return $predictions;
    }

    /**
     * Returns only the first prediction when $samples holds a single flat sample
     * (its first element is not an array); otherwise returns all predictions.
     *
     * @param array|mixed $predictions
     *
     * @return mixed
     */
    private function unwrapSingleSample(array $samples, $predictions)
    {
        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Returns the binary file-name suffix for the current OS, derived from the
     * first three letters of PHP_OS: ".exe" for "WIN", "-osx" for "DAR", "" otherwise.
     */
    private function getOSExtension(): string
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));

        if ($os === 'WIN') {
            return '.exe';
        }

        if ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * Formats a float for the command line independently of the current locale
     * and without truncating small values.
     *
     * sprintf's "%F" always prints six decimals (1e-7 becomes "0.000000"), and
     * casting a float to string is locale-dependent before PHP 8 ("1,5").
     */
    private function formatFloat(float $value): string
    {
        return var_export($value, true);
    }

    /**
     * Builds the svm-train command line for the given training-set and model files.
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf(
            '%ssvm-train%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %d -e %s -h %d -b %d %s %s',
            $this->getBinPath(),
            $this->getOSExtension(),
            $this->type,
            $this->kernel,
            $this->formatFloat($this->cost),
            $this->formatFloat($this->nu),
            $this->degree,
            // The "-g" flag is only emitted when a gamma was configured.
            $this->gamma !== null ? ' -g '.$this->formatFloat($this->gamma) : '',
            $this->formatFloat($this->coef0),
            $this->formatFloat($this->epsilon),
            $this->cacheSize,
            $this->formatFloat($this->tolerance),
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Builds the svm-predict command line for the given test-set, model and output files.
     */
    private function buildPredictCommand(
        string $testSetFileName,
        string $modelFileName,
        string $outputFileName,
        bool $probabilityEstimates
    ): string {
        return sprintf(
            '%ssvm-predict%s -b %d %s %s %s',
            $this->getBinPath(),
            $this->getOSExtension(),
            $probabilityEstimates ? 1 : 0,
            escapeshellarg($testSetFileName),
            escapeshellarg($modelFileName),
            escapeshellarg($outputFileName)
        );
    }

    /**
     * Appends DIRECTORY_SEPARATOR to $path (in place) unless it already ends with it.
     */
    private function ensureDirectorySeparator(string &$path): void
    {
        if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
            $path .= DIRECTORY_SEPARATOR;
        }
    }

    /**
     * Checks that $path is a directory containing executable svm-predict,
     * svm-scale and svm-train binaries (with the OS-specific suffix).
     *
     * @param string $path directory path, expected to end with a directory separator
     *
     * @throws InvalidArgumentException on the first failed check
     */
    private function verifyBinPath(string $path): void
    {
        if (!is_dir($path)) {
            throw new InvalidArgumentException(sprintf('The specified path "%s" does not exist', $path));
        }

        $osExtension = $this->getOSExtension();
        foreach (['svm-predict', 'svm-scale', 'svm-train'] as $filename) {
            $filePath = $path.$filename.$osExtension;
            if (!file_exists($filePath)) {
                throw new InvalidArgumentException(sprintf('File "%s" not found', $filePath));
            }

            if (!is_executable($filePath)) {
                throw new InvalidArgumentException(sprintf('File "%s" is not executable', $filePath));
            }
        }
    }
}
368