Passed
Pull Request — master (#250)
by Yuji
02:40
created

SupportVectorMachine::__construct()   B

Complexity

Conditions 1
Paths 1

Size

Total Lines 32
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 32
rs 8.8571
c 0
b 0
f 0
cc 1
eloc 28
nc 1
nop 12

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Exception\InvalidOperationException;
9
use Phpml\Exception\LibsvmCommandException;
10
use Phpml\Helper\Trainable;
11
12
/**
 * Thin wrapper around the libsvm command line tools (svm-train / svm-predict).
 *
 * Training sets, test sets, models and prediction output are exchanged with the
 * binaries through temporary files created under $varPath; the binaries are
 * looked up under $binPath with an OS specific suffix (see getOSExtension()).
 */
class SupportVectorMachine
{
    use Trainable;

    /**
     * Passed to svm-train as "-s".
     *
     * @var int
     */
    private $type;

    /**
     * Passed to svm-train as "-t".
     *
     * @var int
     */
    private $kernel;

    /**
     * Passed to svm-train as "-c".
     *
     * @var float
     */
    private $cost;

    /**
     * Passed to svm-train as "-n".
     *
     * @var float
     */
    private $nu;

    /**
     * Passed to svm-train as "-d".
     *
     * @var int
     */
    private $degree;

    /**
     * Passed to svm-train as "-g"; the flag is omitted entirely when null.
     *
     * @var float|null
     */
    private $gamma;

    /**
     * Passed to svm-train as "-r".
     *
     * @var float
     */
    private $coef0;

    /**
     * Passed to svm-train as "-p".
     *
     * @var float
     */
    private $epsilon;

    /**
     * Passed to svm-train as "-e".
     *
     * @var float
     */
    private $tolerance;

    /**
     * Passed to svm-train as "-m".
     *
     * @var int
     */
    private $cacheSize;

    /**
     * Passed to svm-train as "-h" (formatted as 0/1).
     *
     * @var bool
     */
    private $shrinking;

    /**
     * Passed to svm-train as "-b" (formatted as 0/1); also gates predictProbability().
     *
     * @var bool
     */
    private $probabilityEstimates;

    /**
     * Directory containing the libsvm binaries, always ending with a directory separator.
     *
     * @var string
     */
    private $binPath;

    /**
     * Directory used for temporary files, always ending with a directory separator.
     *
     * @var string
     */
    private $varPath;

    /**
     * Contents of the model file produced by the last successful train() call.
     *
     * @var string
     */
    private $model;

    /**
     * @var array
     */
    private $targets = [];

    /**
     * Stores the libsvm parameters and derives the default bin/var directories
     * from the location of this file (two directories up from __DIR__).
     */
    public function __construct(
        int $type,
        int $kernel,
        float $cost = 1.0,
        float $nu = 0.5,
        int $degree = 3,
        ?float $gamma = null,
        float $coef0 = 0.0,
        float $epsilon = 0.1,
        float $tolerance = 0.001,
        int $cacheSize = 100,
        bool $shrinking = true,
        bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;

        $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..'])).DIRECTORY_SEPARATOR;

        $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Overrides the directory the libsvm binaries are loaded from.
     *
     * @throws InvalidArgumentException when the directory or one of the expected binaries is missing
     *                                  or a binary is not executable
     */
    public function setBinPath(string $binPath): void
    {
        $this->ensureDirectorySeparator($binPath);
        $this->verifyBinPath($binPath);

        $this->binPath = $binPath;
    }

    /**
     * Overrides the directory used for temporary files.
     *
     * @throws InvalidArgumentException when the path is not writable
     */
    public function setVarPath(string $varPath): void
    {
        if (!is_writable($varPath)) {
            throw new InvalidArgumentException(sprintf('The specified path "%s" is not writable', $varPath));
        }

        $this->ensureDirectorySeparator($varPath);
        $this->varPath = $varPath;
    }

    /**
     * Appends the given samples/targets to the ones already stored, runs svm-train
     * on the accumulated set and keeps the resulting model contents in memory.
     *
     * @throws LibsvmCommandException when svm-train exits with a non-zero status
     *                                or its model file cannot be read
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $isSvr = in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true);
        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, $isSvr);

        $trainingSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';
        file_put_contents($trainingSetFileName, $trainingSet);

        $command = $this->buildTrainCommand($trainingSetFileName, $modelFileName);
        $output = [];
        exec(escapeshellcmd($command).' 2>&1', $output, $return);

        unlink($trainingSetFileName);

        if ($return !== 0) {
            // A failed run may still have left a partial model file behind.
            $this->removeIfExists($modelFileName);

            throw new LibsvmCommandException(
                sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
            );
        }

        $model = $this->readAndRemove($modelFileName);
        if ($model === null) {
            throw new LibsvmCommandException(sprintf('Failed reading libsvm model file: "%s"', $modelFileName));
        }

        $this->model = $model;
    }

    /**
     * Returns the model contents stored by train().
     */
    public function getModel(): string
    {
        return $this->model;
    }

    /**
     * Runs svm-predict for the given samples.
     *
     * For the C_SVC and NU_SVC types the raw output is mapped through
     * DataTransformer::predictions(); otherwise it is split into one entry per line.
     * When $samples[0] is not an array (a single sample was passed) only the first
     * prediction is returned.
     *
     * @return array|string
     *
     * @throws LibsvmCommandException
     */
    public function predict(array $samples)
    {
        $predictions = $this->runSvmPredict($samples, false);

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::predictions($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Runs svm-predict with probability estimates enabled ("-b 1").
     *
     * For the C_SVC and NU_SVC types the raw output is mapped through
     * DataTransformer::probabilities(); otherwise it is split into one entry per line.
     * When $samples[0] is not an array (a single sample was passed) only the first
     * entry is returned.
     *
     * @return array|string
     *
     * @throws InvalidOperationException when the instance was built without probability estimates
     * @throws LibsvmCommandException
     */
    public function predictProbability(array $samples)
    {
        if (!$this->probabilityEstimates) {
            throw new InvalidOperationException('Model does not support probabiliy estimates');
        }

        $predictions = $this->runSvmPredict($samples, true);

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::probabilities($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Writes the test set and the stored model to temporary files, runs svm-predict
     * and returns the raw contents of its output file. All temporary files are removed
     * regardless of the outcome.
     *
     * @throws LibsvmCommandException when svm-predict exits with a non-zero status
     *                                or its output file cannot be read
     */
    private function runSvmPredict(array $samples, bool $probabilityEstimates): string
    {
        $testSet = DataTransformer::testSet($samples);

        $testSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        file_put_contents($testSetFileName, $testSet);
        file_put_contents($modelFileName, $this->model);

        $command = $this->buildPredictCommand(
            $testSetFileName,
            $modelFileName,
            $outputFileName,
            $probabilityEstimates
        );
        $output = [];
        exec(escapeshellcmd($command).' 2>&1', $output, $return);

        unlink($testSetFileName);
        unlink($modelFileName);

        // Check the exit status before touching the output file: a failed run may
        // not have created it, and reading it first would only add warnings on top
        // of the real error.
        if ($return !== 0) {
            $this->removeIfExists($outputFileName);

            throw new LibsvmCommandException(
                sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
            );
        }

        $predictions = $this->readAndRemove($outputFileName);
        if ($predictions === null) {
            throw new LibsvmCommandException(sprintf('Failed reading libsvm output file: "%s"', $outputFileName));
        }

        return $predictions;
    }

    /**
     * Returns the suffix appended to the binary names: ".exe" when PHP_OS starts
     * with "WIN", "-osx" when it starts with "DAR", an empty string otherwise.
     */
    private function getOSExtension(): string
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));

        if ($os === 'WIN') {
            return '.exe';
        }

        if ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * Assembles the svm-train command line from the configured parameters.
     * File names are shell-escaped; floats are rendered locale-independently.
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf(
            '%ssvm-train%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            $this->binPath,
            $this->getOSExtension(),
            $this->type,
            $this->kernel,
            $this->formatFloat($this->cost),
            $this->formatFloat($this->nu),
            $this->degree,
            $this->gamma !== null ? ' -g '.$this->formatFloat($this->gamma) : '',
            $this->formatFloat($this->coef0),
            $this->formatFloat($this->epsilon),
            $this->cacheSize,
            $this->formatFloat($this->tolerance),
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Assembles the svm-predict command line; file names are shell-escaped.
     */
    private function buildPredictCommand(
        string $testSetFileName,
        string $modelFileName,
        string $outputFileName,
        bool $probabilityEstimates
    ): string {
        return sprintf(
            '%ssvm-predict%s -b %d %s %s %s',
            $this->binPath,
            $this->getOSExtension(),
            $probabilityEstimates ? 1 : 0,
            escapeshellarg($testSetFileName),
            escapeshellarg($modelFileName),
            escapeshellarg($outputFileName)
        );
    }

    /**
     * Renders a float for use on the command line.
     *
     * A plain string cast honours LC_NUMERIC before PHP 8 and may emit a decimal
     * comma; var_export() always uses "." as the decimal point.
     */
    private function formatFloat(float $value): string
    {
        return var_export($value, true);
    }

    /**
     * Reads a temporary file and deletes it afterwards.
     *
     * @return string|null the file contents, or null when the file is missing or unreadable
     */
    private function readAndRemove(string $fileName): ?string
    {
        if (!is_file($fileName)) {
            return null;
        }

        $contents = file_get_contents($fileName);
        unlink($fileName);

        return $contents === false ? null : $contents;
    }

    /**
     * Deletes a temporary file when it exists; does nothing otherwise.
     */
    private function removeIfExists(string $fileName): void
    {
        if (is_file($fileName)) {
            unlink($fileName);
        }
    }

    /**
     * Appends DIRECTORY_SEPARATOR to $path (in place) unless it already ends with one.
     */
    private function ensureDirectorySeparator(string &$path): void
    {
        if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
            $path .= DIRECTORY_SEPARATOR;
        }
    }

    /**
     * Checks that $path is a directory containing executable svm-predict,
     * svm-scale and svm-train binaries (with the OS specific suffix).
     *
     * @throws InvalidArgumentException on the first check that fails
     */
    private function verifyBinPath(string $path): void
    {
        if (!is_dir($path)) {
            throw new InvalidArgumentException(sprintf('The specified path "%s" does not exist', $path));
        }

        $osExtension = $this->getOSExtension();
        foreach (['svm-predict', 'svm-scale', 'svm-train'] as $filename) {
            $filePath = $path.$filename.$osExtension;
            if (!file_exists($filePath)) {
                throw new InvalidArgumentException(sprintf('File "%s" not found', $filePath));
            }

            if (!is_executable($filePath)) {
                throw new InvalidArgumentException(sprintf('File "%s" is not executable', $filePath));
            }
        }
    }
}
335