Passed
Pull Request — master (#88)
Created by unknown at 02:53

SupportVectorMachine::setBinPath()   B

Complexity

Conditions 4
Paths 4

Size

Total Lines 23
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 23
rs 8.7972
c 0
b 0
f 0
cc 4
eloc 12
nc 4
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Helper\Trainable;
8
use Phpml\Exception\SupportVectorMachineException;
9
10
/**
 * Thin wrapper around the libsvm command-line executables (svm-train / svm-predict).
 *
 * Training data, models and prediction output are exchanged with the executables through
 * temporary files created in $varPath; those files are always removed after each call.
 */
class SupportVectorMachine
{
    use Trainable;

    /**
     * SVM formulation, passed to svm-train as `-s`; compared against the Type::* constants.
     *
     * @var int
     */
    private $type;

    /**
     * Kernel type, passed to svm-train as `-t`.
     *
     * @var int
     */
    private $kernel;

    /**
     * Cost parameter, passed to svm-train as `-c`.
     *
     * @var float
     */
    private $cost;

    /**
     * Nu parameter, passed to svm-train as `-n`.
     *
     * @var float
     */
    private $nu;

    /**
     * Kernel degree, passed to svm-train as `-d`.
     *
     * @var int
     */
    private $degree;

    /**
     * Kernel gamma, passed to svm-train as `-g`; when null the `-g` option is omitted entirely.
     *
     * @var float|null
     */
    private $gamma;

    /**
     * Kernel coef0, passed to svm-train as `-r`.
     *
     * @var float
     */
    private $coef0;

    /**
     * Epsilon, passed to svm-train as `-p`.
     *
     * @var float
     */
    private $epsilon;

    /**
     * Tolerance, passed to svm-train as `-e`.
     *
     * @var float
     */
    private $tolerance;

    /**
     * Cache size, passed to svm-train as `-m`.
     *
     * @var int
     */
    private $cacheSize;

    /**
     * Shrinking flag, passed to svm-train as `-h` (0 or 1).
     *
     * @var bool
     */
    private $shrinking;

    /**
     * Probability-estimates flag, passed to svm-train as `-b` (0 or 1).
     *
     * @var bool
     */
    private $probabilityEstimates;

    /**
     * Directory containing the libsvm executables; always ends with DIRECTORY_SEPARATOR.
     *
     * @var string
     */
    private $binPath;

    /**
     * Directory where temporary data/model/output files are written; always ends with DIRECTORY_SEPARATOR.
     *
     * @var string
     */
    private $varPath;

    /**
     * Contents of the model file produced by the last successful train() call; null before training.
     *
     * @var string|null
     */
    private $model;

    /**
     * Targets accumulated across all train() calls.
     *
     * @var array
     */
    private $targets = [];

    /**
     * Stores the libsvm parameters and sets default executable and temporary-file directories.
     *
     * The defaults are `<root>/bin/libsvm/` and `<root>/var/`, where `<root>` is three
     * directories above this file; use setBinPath()/setVarPath() to override them.
     *
     * @param int        $type
     * @param int        $kernel
     * @param float      $cost
     * @param float      $nu
     * @param int        $degree
     * @param float|null $gamma
     * @param float      $coef0
     * @param float      $epsilon
     * @param float      $tolerance
     * @param int        $cacheSize
     * @param bool       $shrinking
     * @param bool       $probabilityEstimates
     */
    public function __construct(
        int $type, int $kernel, float $cost = 1.0, float $nu = 0.5, int $degree = 3,
        float $gamma = null, float $coef0 = 0.0, float $epsilon = 0.1, float $tolerance = 0.001,
        int $cacheSize = 100, bool $shrinking = true, bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;

        $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..', '..'])).DIRECTORY_SEPARATOR;

        $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Points the wrapper at a directory containing the libsvm executables.
     *
     * The directory must exist and contain executable `svm-train` and `svm-predict` files
     * carrying the platform suffix returned by getOSExtension().
     *
     * @param string $binPath
     *
     * @return $this
     *
     * @throws SupportVectorMachineException when the directory or either executable is missing
     */
    public function setBinPath(string $binPath)
    {
        if (!is_dir($binPath)) {
            throw SupportVectorMachineException::missingBinPath($binPath);
        }

        $binPath = realpath($binPath).DIRECTORY_SEPARATOR;
        $extension = $this->getOSExtension();

        $trainFile = sprintf('%ssvm-train%s', $binPath, $extension);
        if (!is_executable($trainFile)) {
            throw SupportVectorMachineException::missingTrainFile($trainFile);
        }

        $predictFile = sprintf('%ssvm-predict%s', $binPath, $extension);
        if (!is_executable($predictFile)) {
            throw SupportVectorMachineException::missingPredictFile($predictFile);
        }

        $this->binPath = $binPath;

        return $this;
    }

    /**
     * Sets the directory used for temporary data, model and output files.
     *
     * @param string $varPath
     *
     * @return $this
     *
     * @throws SupportVectorMachineException when the directory does not exist
     */
    public function setVarPath(string $varPath)
    {
        if (!is_dir($varPath)) {
            throw SupportVectorMachineException::missingVarPath($varPath);
        }

        $this->varPath = realpath($varPath).DIRECTORY_SEPARATOR;

        return $this;
    }

    /**
     * Appends the given samples/targets to everything seen so far and retrains on the full set.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws SupportVectorMachineException when a temporary file cannot be written/read or svm-train fails
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // Regression formulations keep targets as numeric values instead of class labels.
        $isRegression = in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true);
        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, $isRegression);

        $trainingSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';

        try {
            $this->writeFile($trainingSetFileName, $trainingSet);
            $this->runCommand($this->buildTrainCommand($trainingSetFileName, $modelFileName));
            $this->model = $this->readFile($modelFileName);
        } finally {
            // Always clean up, even when svm-train fails half-way.
            $this->removeFiles($trainingSetFileName, $modelFileName);
        }
    }

    /**
     * Returns the contents of the trained model file.
     *
     * @return string|null null when train() has not completed successfully yet
     */
    public function getModel()
    {
        return $this->model;
    }

    /**
     * Predicts targets for one sample or a list of samples using the trained model.
     *
     * @param array $samples a single sample (list of features) or a list of samples
     *
     * @return array|mixed a list of predictions, or a single prediction when a single sample was given
     *
     * @throws SupportVectorMachineException when a temporary file cannot be written/read or svm-predict fails
     */
    public function predict(array $samples)
    {
        if ($samples === []) {
            return [];
        }

        $testSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        try {
            $this->writeFile($testSetFileName, DataTransformer::testSet($samples));
            // An untrained model yields an empty file, which makes svm-predict fail loudly below.
            $this->writeFile($modelFileName, (string) $this->model);
            $this->runCommand($this->buildPredictCommand($testSetFileName, $modelFileName, $outputFileName));
            $predictions = $this->readFile($outputFileName);
        } finally {
            $this->removeFiles($testSetFileName, $modelFileName, $outputFileName);
        }

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::predictions($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        // A flat feature list means a single sample was passed: return a single prediction.
        $firstSample = reset($samples);
        if (!is_array($firstSample)) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Returns the executable-name suffix used by the bundled libsvm binaries on this OS.
     *
     * @return string '.exe' on Windows, '-osx' on macOS (Darwin), '' otherwise
     */
    private function getOSExtension()
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));
        if ($os === 'WIN') {
            return '.exe';
        } elseif ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * Builds the svm-train command line; every path is escaped as a single shell argument.
     *
     * @param string $trainingSetFileName
     * @param string $modelFileName
     *
     * @return string
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf('%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            escapeshellarg($this->binPath.'svm-train'.$this->getOSExtension()),
            $this->type,
            $this->kernel,
            $this->cost,
            $this->nu,
            $this->degree,
            $this->gamma !== null ? ' -g '.$this->gamma : '',
            $this->coef0,
            $this->epsilon,
            $this->cacheSize,
            $this->tolerance,
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Builds the svm-predict command line; every path is escaped as a single shell argument.
     *
     * @param string $testSetFileName
     * @param string $modelFileName
     * @param string $outputFileName
     *
     * @return string
     */
    private function buildPredictCommand(string $testSetFileName, string $modelFileName, string $outputFileName): string
    {
        return sprintf('%s %s %s %s',
            escapeshellarg($this->binPath.'svm-predict'.$this->getOSExtension()),
            escapeshellarg($testSetFileName),
            escapeshellarg($modelFileName),
            escapeshellarg($outputFileName)
        );
    }

    /**
     * Runs a fully-escaped shell command and fails when it exits with a non-zero status.
     *
     * @param string $command
     *
     * @throws SupportVectorMachineException when the command exits with a non-zero status
     */
    private function runCommand(string $command)
    {
        $output = [];
        $exitCode = 0;
        exec($command, $output, $exitCode);

        if ($exitCode !== 0) {
            throw new SupportVectorMachineException(sprintf(
                'Failed running libsvm command: "%s" (exit code %d): %s',
                $command,
                $exitCode,
                implode(PHP_EOL, $output)
            ));
        }
    }

    /**
     * Writes $contents to $fileName.
     *
     * @param string $fileName
     * @param string $contents
     *
     * @throws SupportVectorMachineException when the file cannot be written
     */
    private function writeFile(string $fileName, string $contents)
    {
        if (file_put_contents($fileName, $contents) === false) {
            throw new SupportVectorMachineException(sprintf('Failed to write file "%s"', $fileName));
        }
    }

    /**
     * Reads the whole contents of $fileName.
     *
     * @param string $fileName
     *
     * @return string
     *
     * @throws SupportVectorMachineException when the file is missing or unreadable
     */
    private function readFile(string $fileName): string
    {
        $contents = is_file($fileName) ? file_get_contents($fileName) : false;
        if ($contents === false) {
            throw new SupportVectorMachineException(sprintf('Failed to read file "%s"', $fileName));
        }

        return $contents;
    }

    /**
     * Deletes the given files, silently skipping those that were never created.
     *
     * @param string ...$fileNames
     */
    private function removeFiles(string ...$fileNames)
    {
        foreach ($fileNames as $fileName) {
            if (is_file($fileName)) {
                unlink($fileName);
            }
        }
    }
}
288