Completed
Push — master ( ba2b8c...61d2b7 )
by Arkadiusz
07:09
created

SupportVectorMachine::ensureDirectorySeparator()   A

Complexity

Conditions 2
Paths 2

Size

Total Lines 6
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 6
rs 9.4285
c 0
b 0
f 0
cc 2
eloc 3
nc 2
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Helper\Trainable;
9
10
/**
 * Wrapper around the libsvm command-line tools.
 *
 * Training data, models and predictions are exchanged with the `svm-train` /
 * `svm-predict` executables found in $binPath through temporary files created
 * in $varPath; the executables are run with exec().
 */
class SupportVectorMachine
{
    use Trainable;

    /**
     * @var int SVM type, passed to svm-train as `-s`
     */
    private $type;

    /**
     * @var int kernel type, passed to svm-train as `-t`
     */
    private $kernel;

    /**
     * @var float passed to svm-train as `-c`
     */
    private $cost;

    /**
     * @var float passed to svm-train as `-n`
     */
    private $nu;

    /**
     * @var int passed to svm-train as `-d`
     */
    private $degree;

    /**
     * @var float|null passed to svm-train as `-g`; the flag is omitted when null
     */
    private $gamma;

    /**
     * @var float passed to svm-train as `-r`
     */
    private $coef0;

    /**
     * @var float passed to svm-train as `-p`
     */
    private $epsilon;

    /**
     * @var float passed to svm-train as `-e`
     */
    private $tolerance;

    /**
     * @var int passed to svm-train as `-m`
     */
    private $cacheSize;

    /**
     * @var bool passed to svm-train as `-h` (0/1)
     */
    private $shrinking;

    /**
     * @var bool passed to svm-train as `-b` (0/1)
     */
    private $probabilityEstimates;

    /**
     * @var string directory (with trailing separator) holding the libsvm executables
     */
    private $binPath;

    /**
     * @var string writable directory (with trailing separator) for temporary files
     */
    private $varPath;

    /**
     * @var string|null contents of the libsvm model file; null until train() has run
     */
    private $model;

    /**
     * @var array all targets passed to train() so far
     */
    private $targets = [];

    /**
     * @param int        $type                 SVM type (`-s`)
     * @param int        $kernel               kernel type (`-t`)
     * @param float      $cost                 `-c`
     * @param float      $nu                   `-n`
     * @param int        $degree               `-d`
     * @param float|null $gamma                `-g`; omitted from the command when null
     * @param float      $coef0                `-r`
     * @param float      $epsilon              `-p`
     * @param float      $tolerance            `-e`
     * @param int        $cacheSize            `-m`
     * @param bool       $shrinking            `-h`
     * @param bool       $probabilityEstimates `-b`
     */
    public function __construct(
        int $type,
        int $kernel,
        float $cost = 1.0,
        float $nu = 0.5,
        int $degree = 3,
        float $gamma = null,
        float $coef0 = 0.0,
        float $epsilon = 0.1,
        float $tolerance = 0.001,
        int $cacheSize = 100,
        bool $shrinking = true,
        bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;

        // Defaults: <three levels above this file>/bin/libsvm/ and <...>/var/
        $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..', '..'])).DIRECTORY_SEPARATOR;

        $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Points the wrapper at a directory containing the libsvm executables.
     *
     * @param string $binPath
     *
     * @throws InvalidArgumentException when the directory or one of the executables is missing / not executable
     */
    public function setBinPath(string $binPath)
    {
        $binPath = $this->ensureDirectorySeparator($binPath);
        $this->verifyBinPath($binPath);

        $this->binPath = $binPath;
    }

    /**
     * Sets the directory used for temporary training/model/prediction files.
     *
     * @param string $varPath
     *
     * @throws InvalidArgumentException when the path is not writable
     */
    public function setVarPath(string $varPath)
    {
        if (!is_writable($varPath)) {
            throw InvalidArgumentException::pathNotWritable($varPath);
        }

        $this->varPath = $this->ensureDirectorySeparator($varPath);
    }

    /**
     * Appends the samples/targets to the accumulated training data and
     * retrains the model with svm-train on everything seen so far.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \RuntimeException when svm-train fails or produces no model file
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $isRegression = in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true);
        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, $isRegression);

        $trainingSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';

        try {
            file_put_contents($trainingSetFileName, $trainingSet);

            $command = $this->buildTrainCommand($trainingSetFileName, $modelFileName);
            $output = [];
            $returnCode = 0;
            // 2>&1 so libsvm's error messages end up in $output for the exception below.
            exec(escapeshellcmd($command).' 2>&1', $output, $returnCode);

            $model = is_file($modelFileName) ? file_get_contents($modelFileName) : false;
            if ($returnCode !== 0 || $model === false) {
                throw $this->commandFailed($command, $returnCode, $output);
            }

            $this->model = $model;
        } finally {
            // Temporary files are removed even when training fails.
            $this->removeFiles([$trainingSetFileName, $modelFileName]);
        }
    }

    /**
     * @return string|null the libsvm model contents, or null if train() has not been called
     */
    public function getModel()
    {
        return $this->model;
    }

    /**
     * Predicts targets for the given samples using the trained model.
     *
     * @param array $samples a list of samples, or a single sample (a flat array of features)
     *
     * @return array|mixed a list of predictions, or a single prediction when $samples is one sample;
     *                     an empty array when $samples is empty
     *
     * @throws \RuntimeException when svm-predict fails or produces no output file
     */
    public function predict(array $samples)
    {
        if ($samples === []) {
            return [];
        }

        $testSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        try {
            file_put_contents($testSetFileName, DataTransformer::testSet($samples));
            file_put_contents($modelFileName, $this->model);

            $command = sprintf('%ssvm-predict%s %s %s %s', $this->binPath, $this->getOSExtension(), $testSetFileName, $modelFileName, $outputFileName);
            $output = [];
            $returnCode = 0;
            exec(escapeshellcmd($command).' 2>&1', $output, $returnCode);

            $predictions = is_file($outputFileName) ? file_get_contents($outputFileName) : false;
            if ($returnCode !== 0 || $predictions === false) {
                throw $this->commandFailed($command, $returnCode, $output);
            }
        } finally {
            $this->removeFiles([$testSetFileName, $modelFileName, $outputFileName]);
        }

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            // Classification: map libsvm's numeric labels back to the original targets.
            $predictions = DataTransformer::predictions($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        // A single (flat) sample yields a single prediction.
        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * @return string suffix of the libsvm executable names for the current OS
     */
    private function getOSExtension()
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));
        if ($os === 'WIN') {
            return '.exe';
        } elseif ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * Builds the svm-train command line from the configured parameters.
     *
     * NOTE(review): floats are rendered with `%s`; on PHP < 8 float-to-string
     * conversion follows LC_NUMERIC, so a locale with a decimal comma would
     * produce values libsvm may not parse.
     *
     * @param string $trainingSetFileName
     * @param string $modelFileName
     *
     * @return string
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf(
            '%ssvm-train%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            $this->binPath,
            $this->getOSExtension(),
            $this->type,
            $this->kernel,
            $this->cost,
            $this->nu,
            $this->degree,
            $this->gamma !== null ? ' -g '.$this->gamma : '',
            $this->coef0,
            $this->epsilon,
            $this->cacheSize,
            $this->tolerance,
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Returns the path with a trailing directory separator, appending
     * DIRECTORY_SEPARATOR unless it already ends in one (or in '/', which is
     * accepted as a separator on every platform PHP runs on).
     *
     * @param string $path
     *
     * @return string
     */
    private function ensureDirectorySeparator(string $path): string
    {
        $lastChar = substr($path, -1);
        if ($lastChar === DIRECTORY_SEPARATOR || $lastChar === '/') {
            return $path;
        }

        return $path.DIRECTORY_SEPARATOR;
    }

    /**
     * Checks that $path is a directory containing executable svm-predict,
     * svm-scale and svm-train binaries (with the OS-specific suffix).
     *
     * @param string $path directory path ending with a separator
     *
     * @throws InvalidArgumentException
     */
    private function verifyBinPath(string $path)
    {
        if (!is_dir($path)) {
            throw InvalidArgumentException::pathNotFound($path);
        }

        $osExtension = $this->getOSExtension();
        foreach (['svm-predict', 'svm-scale', 'svm-train'] as $filename) {
            $filePath = $path.$filename.$osExtension;
            if (!file_exists($filePath)) {
                throw InvalidArgumentException::fileNotFound($filePath);
            }

            if (!is_executable($filePath)) {
                throw InvalidArgumentException::fileNotExecutable($filePath);
            }
        }
    }

    /**
     * Deletes the given temporary files, skipping any that do not exist
     * (e.g. a model/output file the libsvm tool never wrote).
     *
     * @param string[] $fileNames
     */
    private function removeFiles(array $fileNames)
    {
        foreach ($fileNames as $fileName) {
            if (is_file($fileName)) {
                unlink($fileName);
            }
        }
    }

    /**
     * Builds the exception reported when a libsvm command fails.
     *
     * @param string   $command    the command that was run
     * @param int      $returnCode its exit status
     * @param string[] $output     its combined stdout/stderr lines
     *
     * @return \RuntimeException
     */
    private function commandFailed(string $command, int $returnCode, array $output): \RuntimeException
    {
        return new \RuntimeException(sprintf(
            'libsvm command "%s" failed with exit code %d: %s',
            $command,
            $returnCode,
            implode(PHP_EOL, $output)
        ));
    }
}
313