Passed
Pull Request — master (#108)
by
unknown
03:23
created

SupportVectorMachine::setBinPath()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 6
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 6
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 3
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Helper\Trainable;
8
9
/**
 * Wrapper around the libsvm command-line tools (svm-train / svm-predict).
 *
 * Training sets, models and prediction outputs are exchanged with the binaries
 * through temporary files created in $varPath; the binaries are resolved from $binPath.
 */
class SupportVectorMachine
{
    use Trainable;

    /**
     * @var int
     */
    private $type;

    /**
     * @var int
     */
    private $kernel;

    /**
     * @var float
     */
    private $cost;

    /**
     * @var float
     */
    private $nu;

    /**
     * @var int
     */
    private $degree;

    /**
     * When null, no "-g" option is passed to svm-train.
     *
     * @var float|null
     */
    private $gamma;

    /**
     * @var float
     */
    private $coef0;

    /**
     * @var float
     */
    private $epsilon;

    /**
     * @var float
     */
    private $tolerance;

    /**
     * @var int
     */
    private $cacheSize;

    /**
     * @var bool
     */
    private $shrinking;

    /**
     * @var bool
     */
    private $probabilityEstimates;

    /**
     * Directory holding the libsvm binaries (with trailing separator, or '' for PATH lookup).
     *
     * @var string
     */
    private $binPath;

    /**
     * Directory where temporary files are written (with trailing separator).
     *
     * @var string
     */
    private $varPath;

    /**
     * Model file contents produced by the last successful train(); null before training.
     *
     * @var string|null
     */
    private $model;

    /**
     * @var array
     */
    private $targets = [];

    /**
     * @param int        $type
     * @param int        $kernel
     * @param float      $cost
     * @param float      $nu
     * @param int        $degree
     * @param float|null $gamma                null omits the "-g" option from the svm-train call
     * @param float      $coef0
     * @param float      $epsilon
     * @param float      $tolerance
     * @param int        $cacheSize
     * @param bool       $shrinking
     * @param bool       $probabilityEstimates
     */
    public function __construct(
        int $type,
        int $kernel,
        float $cost = 1.0,
        float $nu = 0.5,
        int $degree = 3,
        float $gamma = null,
        float $coef0 = 0.0,
        float $epsilon = 0.1,
        float $tolerance = 0.001,
        int $cacheSize = 100,
        bool $shrinking = true,
        bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;

        $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..', '..'])).DIRECTORY_SEPARATOR;

        $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Set the directory containing the svm-train / svm-predict binaries.
     * A trailing directory separator is appended when missing; '' is kept as-is.
     *
     * @param string $binPath
     *
     * @return $this
     */
    public function setBinPath(string $binPath)
    {
        $this->binPath = $this->normalizeDirectoryPath($binPath);

        return $this;
    }

    /**
     * Set the directory used for temporary training/model/output files.
     * A trailing directory separator is appended when missing; '' is kept as-is.
     *
     * @param string $varPath
     *
     * @return $this
     */
    public function setVarPath(string $varPath)
    {
        $this->varPath = $this->normalizeDirectoryPath($varPath);

        return $this;
    }

    /**
     * Append the samples/targets to the accumulated training data and retrain the
     * libsvm model on everything seen so far.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \RuntimeException when a temporary file cannot be written/read or svm-train fails
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, $this->isRegression());
        $trainingSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';

        try {
            $this->writeFile($trainingSetFileName, $trainingSet);
            $this->runCommand($this->buildTrainCommand($trainingSetFileName, $modelFileName));
            $this->model = $this->readFile($modelFileName);
        } finally {
            // Always clean up, even when svm-train failed part-way.
            $this->removeFiles([$trainingSetFileName, $modelFileName]);
        }
    }

    /**
     * @return string|null model file contents, or null if train() has not succeeded yet
     */
    public function getModel()
    {
        return $this->model;
    }

    /**
     * Predict targets for one sample (a flat array of features) or many samples
     * (an array of such arrays).
     *
     * @param array $samples
     *
     * @return mixed a single prediction when $samples is one flat sample, otherwise an array
     *               of predictions ([] for empty input)
     *
     * @throws \RuntimeException when a temporary file cannot be written/read or svm-predict fails
     */
    public function predict(array $samples)
    {
        if ($samples === []) {
            return [];
        }

        $testSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        try {
            $this->writeFile($testSetFileName, DataTransformer::testSet($samples));
            // An untrained model is written as an empty file; svm-predict then fails
            // and runCommand() reports its error output.
            $this->writeFile($modelFileName, (string) $this->model);
            $this->runCommand($this->buildPredictCommand($testSetFileName, $modelFileName, $outputFileName));
            $rawPredictions = $this->readFile($outputFileName);
        } finally {
            $this->removeFiles([$testSetFileName, $modelFileName, $outputFileName]);
        }

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            $predictions = DataTransformer::predictions($rawPredictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($rawPredictions));
        }

        // A flat (single) sample yields a single prediction rather than a list.
        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Suffix of the libsvm binaries for the current OS ('.exe' on Windows, '-osx' on macOS).
     *
     * @return string
     */
    private function getOSExtension(): string
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));
        if ($os === 'WIN') {
            return '.exe';
        } elseif ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * @return bool true for the regression types, whose targets are passed to libsvm as-is
     */
    private function isRegression(): bool
    {
        return in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true);
    }

    /**
     * Shell-quoted path of a libsvm binary ('svm-train' / 'svm-predict').
     *
     * @param string $name
     *
     * @return string
     */
    private function binary(string $name): string
    {
        return escapeshellarg($this->binPath.$name.$this->getOSExtension());
    }

    /**
     * @param string $trainingSetFileName
     * @param string $modelFileName
     *
     * @return string
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        // Every path is quoted individually; the remaining values are numeric.
        return sprintf(
            '%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            $this->binary('svm-train'),
            $this->type,
            $this->kernel,
            $this->cost,
            $this->nu,
            $this->degree,
            $this->gamma !== null ? ' -g '.$this->gamma : '',
            $this->coef0,
            $this->epsilon,
            $this->cacheSize,
            $this->tolerance,
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * @param string $testSetFileName
     * @param string $modelFileName
     * @param string $outputFileName
     *
     * @return string
     */
    private function buildPredictCommand(string $testSetFileName, string $modelFileName, string $outputFileName): string
    {
        return sprintf(
            '%s %s %s %s',
            $this->binary('svm-predict'),
            escapeshellarg($testSetFileName),
            escapeshellarg($modelFileName),
            escapeshellarg($outputFileName)
        );
    }

    /**
     * Run an already-escaped command and fail loudly on a non-zero exit status.
     *
     * @param string $command
     *
     * @throws \RuntimeException
     */
    private function runCommand(string $command)
    {
        $output = [];
        $returnCode = null;
        // stderr is captured so the libsvm error message can be reported.
        exec($command.' 2>&1', $output, $returnCode);

        if ($returnCode !== 0) {
            throw new \RuntimeException(sprintf(
                'Failed running libsvm command: "%s" with reason: "%s"',
                $command,
                implode(PHP_EOL, $output)
            ));
        }
    }

    /**
     * @param string $fileName
     * @param string $contents
     *
     * @throws \RuntimeException
     */
    private function writeFile(string $fileName, string $contents)
    {
        if (file_put_contents($fileName, $contents) === false) {
            throw new \RuntimeException(sprintf('Unable to write file "%s"', $fileName));
        }
    }

    /**
     * @param string $fileName
     *
     * @return string
     *
     * @throws \RuntimeException
     */
    private function readFile(string $fileName): string
    {
        $contents = is_file($fileName) ? file_get_contents($fileName) : false;
        if ($contents === false) {
            throw new \RuntimeException(sprintf('Unable to read file "%s"', $fileName));
        }

        return $contents;
    }

    /**
     * Delete the given files, skipping those that were never created.
     *
     * @param string[] $fileNames
     */
    private function removeFiles(array $fileNames)
    {
        foreach ($fileNames as $fileName) {
            if (is_file($fileName)) {
                unlink($fileName);
            }
        }
    }

    /**
     * Ensure a non-empty directory path ends with a separator so file names can be appended.
     *
     * @param string $path
     *
     * @return string
     */
    private function normalizeDirectoryPath(string $path): string
    {
        if ($path === '' || in_array(substr($path, -1), ['/', DIRECTORY_SEPARATOR], true)) {
            return $path;
        }

        return $path.DIRECTORY_SEPARATOR;
    }
}
274