Passed
Pull Request — master (#88)
by
unknown
04:18 queued 01:30
created

SupportVectorMachine::getPredictFile()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 4
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
eloc 2
nc 1
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Helper\Trainable;
8
9
/**
 * Thin wrapper around the libsvm command line binaries (svm-train / svm-predict).
 *
 * Training data and models are exchanged with the binaries through temporary
 * files created below $varPath; the trained model is kept in memory as a string.
 */
class SupportVectorMachine
{
    use Trainable;

    /**
     * @var int value passed to svm-train as "-s"
     */
    private $type;

    /**
     * @var int value passed to svm-train as "-t"
     */
    private $kernel;

    /**
     * @var float value passed to svm-train as "-c"
     */
    private $cost;

    /**
     * @var float value passed to svm-train as "-n"
     */
    private $nu;

    /**
     * @var int value passed to svm-train as "-d"
     */
    private $degree;

    /**
     * @var float|null value passed to svm-train as "-g"; the flag is omitted when null
     */
    private $gamma;

    /**
     * @var float value passed to svm-train as "-r"
     */
    private $coef0;

    /**
     * @var float value passed to svm-train as "-p"
     */
    private $epsilon;

    /**
     * @var float value passed to svm-train as "-e"
     */
    private $tolerance;

    /**
     * @var int value passed to svm-train as "-m"
     */
    private $cacheSize;

    /**
     * @var bool passed to svm-train as "-h" (formatted as 0/1)
     */
    private $shrinking;

    /**
     * @var bool passed to svm-train as "-b" (formatted as 0/1)
     */
    private $probabilityEstimates;

    /**
     * @var string directory holding the libsvm binaries, with trailing separator
     */
    private $binPath;

    /**
     * @var string directory used for temporary files, with trailing separator
     */
    private $varPath;

    /**
     * @var string contents of the model file produced by the last training run
     */
    private $model;

    /**
     * @var array
     */
    private $targets = [];

    /**
     * Stores the libsvm parameters and points $binPath / $varPath at the
     * "bin/libsvm" and "var" directories three levels above this file.
     *
     * @param int        $type
     * @param int        $kernel
     * @param float      $cost
     * @param float      $nu
     * @param int        $degree
     * @param float|null $gamma                null leaves the "-g" flag out of the train command
     * @param float      $coef0
     * @param float      $epsilon
     * @param float      $tolerance
     * @param int        $cacheSize
     * @param bool       $shrinking
     * @param bool       $probabilityEstimates
     */
    public function __construct(
        int $type, int $kernel, float $cost = 1.0, float $nu = 0.5, int $degree = 3,
        float $gamma = null, float $coef0 = 0.0, float $epsilon = 0.1, float $tolerance = 0.001,
        int $cacheSize = 100, bool $shrinking = true, bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;

        $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..', '..'])).DIRECTORY_SEPARATOR;

        $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Points the wrapper at another directory containing the libsvm binaries.
     *
     * The new path is only stored once both binaries were found to be
     * executable, so a rejected path leaves the object unchanged.
     *
     * @param string $binPath
     *
     * @return $this
     *
     * @throws \InvalidArgumentException when the directory or one of the binaries is missing
     */
    public function setBinPath(string $binPath)
    {
        if (!is_dir($binPath)) {
            throw new \InvalidArgumentException(sprintf('Bin path "%s" missing.', $binPath));
        }

        $directory = realpath($binPath).DIRECTORY_SEPARATOR;

        $trainFile = $this->binaryFile($directory, 'svm-train');
        if (!is_executable($trainFile)) {
            throw new \InvalidArgumentException(sprintf('Train file "%s" missing.', $trainFile));
        }

        $predictFile = $this->binaryFile($directory, 'svm-predict');
        if (!is_executable($predictFile)) {
            throw new \InvalidArgumentException(sprintf('Predict file "%s" missing.', $predictFile));
        }

        $this->binPath = $directory;

        return $this;
    }

    /**
     * Changes the directory in which temporary training/test files are created.
     *
     * @param string $varPath
     *
     * @return $this
     *
     * @throws \InvalidArgumentException when the directory does not exist
     */
    public function setVarPath(string $varPath)
    {
        if (!is_dir($varPath)) {
            throw new \InvalidArgumentException(sprintf('Var path "%s" missing.', $varPath));
        }

        $this->varPath = realpath($varPath).DIRECTORY_SEPARATOR;

        return $this;
    }

    /**
     * Appends the given samples/targets to those already collected, runs
     * svm-train on the accumulated set and stores the resulting model.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \RuntimeException when a temporary file cannot be written/read or svm-train fails
     */
    public function train(array $samples, array $targets)
    {
        // NOTE(review): $this->samples is not declared in this class; presumably it comes from the Trainable trait.
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $trainingSet = DataTransformer::trainingSet(
            $this->samples,
            $this->targets,
            in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR])
        );

        $trainingSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';

        try {
            $this->writeFile($trainingSetFileName, $trainingSet);
            $this->runCommand($this->buildTrainCommand($trainingSetFileName, $modelFileName));
            $this->model = $this->readFile($modelFileName);
        } finally {
            // Temporary files must not outlive the call, even when a step above failed.
            $this->removeFiles([$trainingSetFileName, $modelFileName]);
        }
    }

    /**
     * @return string model file contents stored by train()
     */
    public function getModel()
    {
        return $this->model;
    }

    /**
     * Runs svm-predict for the given samples against the stored model.
     *
     * @param array $samples either a list of samples or a single flat sample
     *
     * @return array|mixed list of predictions, or only the first prediction
     *                     when the first element of $samples is not an array
     *
     * @throws \RuntimeException when a temporary file cannot be written/read or svm-predict fails
     */
    public function predict(array $samples)
    {
        $testSet = DataTransformer::testSet($samples);

        $testSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        try {
            $this->writeFile($testSetFileName, $testSet);
            $this->writeFile($modelFileName, $this->model);

            // File names are quoted the same way buildTrainCommand() does, so a
            // var path containing spaces or shell metacharacters keeps working.
            $command = sprintf(
                '%s %s %s %s',
                $this->getPredictFile(),
                escapeshellarg($testSetFileName),
                escapeshellarg($modelFileName),
                escapeshellarg($outputFileName)
            );
            $this->runCommand($command);

            $predictions = $this->readFile($outputFileName);
        } finally {
            $this->removeFiles([$testSetFileName, $modelFileName, $outputFileName]);
        }

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
            // Classification output is post-processed together with the training targets.
            $predictions = DataTransformer::predictions($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        // A flat $samples array is treated as one sample: hand back its single prediction.
        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Suffix of the libsvm binaries for the current platform, derived from
     * the first three characters of PHP_OS.
     *
     * @return string '.exe' for "WIN", '-osx' for "DAR", '' otherwise
     */
    private function getOSExtension(): string
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));

        if ($os === 'WIN') {
            return '.exe';
        }

        if ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * Builds the path of a libsvm binary inside the given directory.
     *
     * @param string $directory directory path including its trailing separator
     * @param string $name      binary name without the platform suffix
     *
     * @return string
     */
    private function binaryFile(string $directory, string $name): string
    {
        return sprintf('%s%s%s', $directory, $name, $this->getOSExtension());
    }

    /**
     * @return string path of the svm-train binary below the current bin path
     */
    private function getTrainFile(): string
    {
        return $this->binaryFile($this->binPath, 'svm-train');
    }

    /**
     * @return string path of the svm-predict binary below the current bin path
     */
    private function getPredictFile(): string
    {
        return $this->binaryFile($this->binPath, 'svm-predict');
    }

    /**
     * Assembles the svm-train command line from the configured parameters.
     *
     * @param string $trainingSetFileName
     * @param string $modelFileName
     *
     * @return string
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf('%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            $this->getTrainFile(),
            $this->type,
            $this->kernel,
            $this->cost,
            $this->nu,
            $this->degree,
            // The gamma flag is only emitted when a value was configured.
            $this->gamma !== null ? ' -g '.$this->gamma : '',
            $this->coef0,
            $this->epsilon,
            $this->cacheSize,
            $this->tolerance,
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Executes a libsvm command and fails loudly on a non-zero exit status.
     *
     * @param string $command
     *
     * @throws \RuntimeException when the command exits with a non-zero status
     */
    private function runCommand(string $command)
    {
        $output = [];
        $exitCode = 0;

        // The redirection is appended after escaping, otherwise escapeshellcmd() would neutralise it.
        exec(escapeshellcmd($command).' 2>&1', $output, $exitCode);

        if ($exitCode !== 0) {
            throw new \RuntimeException(sprintf(
                'Failed running libsvm command: "%s" with reason: "%s"',
                $command,
                (string) end($output)
            ));
        }
    }

    /**
     * Writes a temporary file.
     *
     * @param string $fileName
     * @param mixed  $contents
     *
     * @throws \RuntimeException when the file cannot be written
     */
    private function writeFile(string $fileName, $contents)
    {
        if (file_put_contents($fileName, $contents) === false) {
            throw new \RuntimeException(sprintf('Unable to write file "%s".', $fileName));
        }
    }

    /**
     * Reads a file produced by one of the libsvm binaries.
     *
     * @param string $fileName
     *
     * @return string
     *
     * @throws \RuntimeException when the file cannot be read
     */
    private function readFile(string $fileName): string
    {
        $contents = file_get_contents($fileName);

        if ($contents === false) {
            throw new \RuntimeException(sprintf('Unable to read file "%s".', $fileName));
        }

        return $contents;
    }

    /**
     * Deletes those of the given files that exist.
     *
     * @param string[] $fileNames
     */
    private function removeFiles(array $fileNames)
    {
        foreach ($fileNames as $fileName) {
            // A failed step may never have created the file; skip it instead of triggering a warning.
            if (is_file($fileName)) {
                unlink($fileName);
            }
        }
    }
}
295