Completed
Pull Request — master (#400)
by
unknown
03:36
created

SupportVectorMachine::__construct()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 26
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
c 0
b 0
f 0
dl 0
loc 26
rs 9.8666
cc 1
nc 1
nop 12

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as soon as you need more or different data.

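There are several approaches to avoid long parameter lists:

- Replace a parameter with a method call, when the method can obtain the value itself.
- Introduce a parameter object that groups related values (for example, all of the training options).
- Preserve the whole object: pass the object that owns several of the values instead of the individual values.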

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Exception\InvalidOperationException;
9
use Phpml\Exception\LibsvmCommandException;
10
use Phpml\Helper\Trainable;
11
12
class SupportVectorMachine
{
    use Trainable;

    /**
     * @var int
     */
    private $type;

    /**
     * @var int
     */
    private $kernel;

    /**
     * @var float
     */
    private $cost;

    /**
     * @var float
     */
    private $nu;

    /**
     * @var int
     */
    private $degree;

    /**
     * @var float|null
     */
    private $gamma;

    /**
     * @var float
     */
    private $coef0;

    /**
     * @var float
     */
    private $epsilon;

    /**
     * @var float
     */
    private $tolerance;

    /**
     * @var int
     */
    private $cacheSize;

    /**
     * @var bool
     */
    private $shrinking;

    /**
     * @var bool
     */
    private $probabilityEstimates;

    /**
     * Directory containing the libsvm executables; null selects the default
     * `<root>/bin/libsvm/` directory (see getBinPath()).
     *
     * @var ?string
     */
    private $binPath = null;

    /**
     * Directory used for temporary data/model/output files; null selects the
     * default `<root>/var/` directory (see getVarPath()).
     *
     * @var ?string
     */
    private $varPath = null;

    /**
     * Contents of the model file written by the last successful svm-train run.
     *
     * @var string
     */
    private $model;

    /**
     * All targets passed to train() so far (accumulated across calls).
     *
     * @var array
     */
    private $targets = [];

    /**
     * Stores the training options. No libsvm process is started here.
     *
     * Every option is forwarded to `svm-train` by buildTrainCommand(), using the flag
     * noted next to it.
     *
     * @param int        $type                 SVM type (`-s`), compared against the Type constants
     * @param int        $kernel               kernel type (`-t`)
     * @param float      $cost                 `-c`
     * @param float      $nu                   `-n`
     * @param int        $degree               `-d`
     * @param float|null $gamma                `-g`; the flag is omitted entirely when null
     * @param float      $coef0                `-r`
     * @param float      $epsilon              `-p`
     * @param float      $tolerance            `-e`
     * @param int        $cacheSize            `-m`
     * @param bool       $shrinking            `-h` (sent as 0/1)
     * @param bool       $probabilityEstimates `-b` (sent as 0/1); must be true for predictProbability()
     */
    public function __construct(
        int $type,
        int $kernel,
        float $cost = 1.0,
        float $nu = 0.5,
        int $degree = 3,
        ?float $gamma = null,
        float $coef0 = 0.0,
        float $epsilon = 0.1,
        float $tolerance = 0.001,
        int $cacheSize = 100,
        bool $shrinking = true,
        bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;
    }

    /**
     * Sets a new bin path for the SVM model.
     * If null is provided, the default bin path will be used.
     *
     * A trailing directory separator is appended when missing. The directory must
     * contain executable svm-predict, svm-scale and svm-train binaries (with the
     * OS-specific suffix).
     *
     * @param string|null $binPath
     *
     * @throws InvalidArgumentException when the directory or a binary is missing or not executable
     */
    public function setBinPath(?string $binPath): void
    {
        if ($binPath !== null) {
            $this->ensureDirectorySeparator($binPath);
            $this->verifyBinPath($binPath);
        }

        $this->binPath = $binPath;
    }

    /**
     * Sets a new var path (temporary file directory) for the SVM model.
     * If null is provided, the default var path will be used.
     *
     * A trailing directory separator is appended when missing.
     *
     * @param string|null $varPath
     *
     * @throws InvalidArgumentException when the path is not writable
     */
    public function setVarPath(?string $varPath): void
    {
        if ($varPath !== null) {
            if (!is_writable($varPath)) {
                throw new InvalidArgumentException(sprintf('The specified path "%s" is not writable', $varPath));
            }

            $this->ensureDirectorySeparator($varPath);
        }

        $this->varPath = $varPath;
    }

    /**
     * Adds the given samples/targets to those seen previously and retrains the
     * model on the full accumulated set by running `svm-train`.
     *
     * The training set and the resulting model are exchanged through temporary
     * files in the var path. Both files are removed afterwards, whether or not the
     * command succeeds.
     *
     * @throws LibsvmCommandException when svm-train exits with a non-zero status
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // Regression types (epsilon-SVR / nu-SVR) get a flag that changes how targets are serialized.
        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true));
        $trainingSetFileName = $this->getVarPath().uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';

        file_put_contents($trainingSetFileName, $trainingSet);

        try {
            $this->runCommand($this->buildTrainCommand($trainingSetFileName, $modelFileName));
            $this->model = (string) file_get_contents($modelFileName);
        } finally {
            // Also removes a partially written model file when svm-train failed.
            $this->removeFiles($trainingSetFileName, $modelFileName);
        }
    }

    /**
     * Returns the model contents produced by the last successful train() call.
     *
     * Only call this after training: before that, $model is unset and the string
     * return type cannot be satisfied.
     */
    public function getModel(): string
    {
        return $this->model;
    }

    /**
     * Predicts labels (C_SVC / NU_SVC) or values (other types) for one sample or a list of samples.
     *
     * For C_SVC/NU_SVC the raw svm-predict output is converted with
     * DataTransformer::predictions() using the stored targets. For other types
     * the output is split into one string per line.
     *
     * @return array|string a single prediction when $samples is one sample (its first
     *                      element is not an array), otherwise the list of predictions
     *
     * @throws LibsvmCommandException when svm-predict exits with a non-zero status
     */
    public function predict(array $samples)
    {
        $predictions = $this->runSvmPredict($samples, false);

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::predictions($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Predicts class probabilities (C_SVC / NU_SVC) or values (other types) for
     * one sample or a list of samples.
     *
     * Requires the model to have been built with probability estimates enabled.
     *
     * @return array|string a single result when $samples is one sample (its first
     *                      element is not an array), otherwise the list of results
     *
     * @throws InvalidOperationException when probability estimates were disabled in the constructor
     * @throws LibsvmCommandException    when svm-predict exits with a non-zero status
     */
    public function predictProbability(array $samples)
    {
        if (!$this->probabilityEstimates) {
            throw new InvalidOperationException('Model does not support probabiliy estimates');
        }

        $predictions = $this->runSvmPredict($samples, true);

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::probabilities($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Directory of the libsvm executables, always ending with a directory separator.
     */
    protected function getBinPath(): string
    {
        return is_string($this->binPath)
            ? $this->binPath
            : static::getRootPath().'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
    }

    /**
     * Directory for temporary files, always ending with a directory separator.
     */
    protected function getVarPath(): string
    {
        return is_string($this->varPath)
            ? $this->varPath
            : static::getRootPath().'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Absolute path two levels above this file's directory, with a trailing separator.
     */
    protected static function getRootPath(): string
    {
        return realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..'])).DIRECTORY_SEPARATOR;
    }

    /**
     * Writes the samples and the current model to temporary files, runs
     * `svm-predict`, and returns the raw contents of its output file.
     *
     * The exit status is checked before the output file is read. A failed run
     * usually produces no output file, and reading it would raise warnings that hide
     * the real error. All temporary files are removed afterwards.
     *
     * @throws LibsvmCommandException when svm-predict exits with a non-zero status
     */
    private function runSvmPredict(array $samples, bool $probabilityEstimates): string
    {
        $testSet = DataTransformer::testSet($samples);
        $testSetFileName = $this->getVarPath().uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        file_put_contents($testSetFileName, $testSet);
        file_put_contents($modelFileName, $this->model);

        try {
            $this->runCommand($this->buildPredictCommand(
                $testSetFileName,
                $modelFileName,
                $outputFileName,
                $probabilityEstimates
            ));

            return (string) file_get_contents($outputFileName);
        } finally {
            $this->removeFiles($testSetFileName, $modelFileName, $outputFileName);
        }
    }

    /**
     * Executes a libsvm command (stderr merged into stdout) and throws if it fails.
     *
     * @throws LibsvmCommandException with the last output line as the reason when
     *                                the exit status is non-zero
     */
    private function runCommand(string $command): void
    {
        $output = [];
        // Pre-set to a failure code so an exec() call that never runs is not treated as success.
        $returnCode = -1;
        exec(escapeshellcmd($command).' 2>&1', $output, $returnCode);

        if ($returnCode !== 0) {
            throw new LibsvmCommandException(
                sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
            );
        }
    }

    /**
     * Deletes each given file if it exists; missing files are ignored.
     */
    private function removeFiles(string ...$paths): void
    {
        foreach ($paths as $path) {
            if (is_file($path)) {
                unlink($path);
            }
        }
    }

    /**
     * Suffix of the libsvm binaries for the current OS: '.exe' on Windows,
     * '-osx' on Darwin (macOS), and none otherwise.
     */
    private function getOSExtension(): string
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));
        if ($os === 'WIN') {
            return '.exe';
        } elseif ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * Builds the `svm-train` command line from the constructor options.
     *
     * Only the two file arguments are shell-escaped here. The whole command is
     * passed through escapeshellcmd() in runCommand().
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf(
            '%ssvm-train%s -s %s -t %s -c %s -n %F -d %s%s -r %s -p %F -m %F -e %F -h %d -b %d %s %s',
            $this->getBinPath(),
            $this->getOSExtension(),
            $this->type,
            $this->kernel,
            $this->cost,
            $this->nu,
            $this->degree,
            // -g is only emitted when gamma was given explicitly.
            $this->gamma !== null ? ' -g '.$this->gamma : '',
            $this->coef0,
            $this->epsilon,
            $this->cacheSize,
            $this->tolerance,
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Builds the `svm-predict` command line; `-b 1` requests probability output.
     */
    private function buildPredictCommand(
        string $testSetFileName,
        string $modelFileName,
        string $outputFileName,
        bool $probabilityEstimates
    ): string {
        return sprintf(
            '%ssvm-predict%s -b %d %s %s %s',
            $this->getBinPath(),
            $this->getOSExtension(),
            $probabilityEstimates ? 1 : 0,
            escapeshellarg($testSetFileName),
            escapeshellarg($modelFileName),
            escapeshellarg($outputFileName)
        );
    }

    /**
     * Appends DIRECTORY_SEPARATOR to $path (in place) when it does not already end with one.
     */
    private function ensureDirectorySeparator(string &$path): void
    {
        if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
            $path .= DIRECTORY_SEPARATOR;
        }
    }

    /**
     * Checks that $path is a directory containing executable svm-predict,
     * svm-scale and svm-train binaries (with the OS-specific suffix).
     *
     * @throws InvalidArgumentException on the first missing directory or binary,
     *                                  or a binary that is not executable
     */
    private function verifyBinPath(string $path): void
    {
        if (!is_dir($path)) {
            throw new InvalidArgumentException(sprintf('The specified path "%s" does not exist', $path));
        }

        $osExtension = $this->getOSExtension();
        foreach (['svm-predict', 'svm-scale', 'svm-train'] as $filename) {
            $filePath = $path.$filename.$osExtension;
            if (!file_exists($filePath)) {
                throw new InvalidArgumentException(sprintf('File "%s" not found', $filePath));
            }

            if (!is_executable($filePath)) {
                throw new InvalidArgumentException(sprintf('File "%s" is not executable', $filePath));
            }
        }
    }
}
368