Passed
Pull Request — master (#73)
by Arkadiusz
03:46 queued 57s
created

SupportVectorMachine::setVarPath()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 6
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 6
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 3
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Helper\Trainable;
8
9
/**
 * Wrapper around the libsvm command-line tools (svm-train / svm-predict).
 *
 * Data is exchanged with the binaries through temporary files created under
 * $varPath; the binaries themselves are looked up under $binPath. The trained
 * model is kept in memory as the raw contents of the model file.
 */
class SupportVectorMachine
{
    use Trainable;

    /**
     * @var int value passed to svm-train as "-s" (see the Type constants)
     */
    private $type;

    /**
     * @var int value passed to svm-train as "-t"
     */
    private $kernel;

    /**
     * @var float value passed to svm-train as "-c"
     */
    private $cost;

    /**
     * @var float value passed to svm-train as "-n"
     */
    private $nu;

    /**
     * @var int value passed to svm-train as "-d"
     */
    private $degree;

    /**
     * @var float|null value passed to svm-train as "-g"; the flag is omitted when null
     */
    private $gamma;

    /**
     * @var float value passed to svm-train as "-r"
     */
    private $coef0;

    /**
     * @var float value passed to svm-train as "-p"
     */
    private $epsilon;

    /**
     * @var float value passed to svm-train as "-e"
     */
    private $tolerance;

    /**
     * @var int value passed to svm-train as "-m"
     */
    private $cacheSize;

    /**
     * @var bool passed to svm-train as "-h" (formatted as 0/1)
     */
    private $shrinking;

    /**
     * @var bool passed to svm-train as "-b" (formatted as 0/1)
     */
    private $probabilityEstimates;

    /**
     * @var string directory prefix of the libsvm binaries; file names are appended directly to it
     */
    private $binPath;

    /**
     * @var string directory prefix for temporary files; file names are appended directly to it
     */
    private $varPath;

    /**
     * @var string|null contents of the model file written by svm-train; unset until train() succeeds
     */
    private $model;

    /**
     * @var array all targets accumulated over successive train() calls
     */
    private $targets = [];

    /**
     * Stores the training options and derives the default bin/var paths from
     * the directory three levels above this file.
     *
     * @param int        $type                 svm-train "-s" value
     * @param int        $kernel               svm-train "-t" value
     * @param float      $cost                 svm-train "-c" value
     * @param float      $nu                   svm-train "-n" value
     * @param int        $degree               svm-train "-d" value
     * @param float|null $gamma                svm-train "-g" value; null omits the flag
     * @param float      $coef0                svm-train "-r" value
     * @param float      $epsilon              svm-train "-p" value
     * @param float      $tolerance            svm-train "-e" value
     * @param int        $cacheSize            svm-train "-m" value
     * @param bool       $shrinking            svm-train "-h" value
     * @param bool       $probabilityEstimates svm-train "-b" value
     */
    public function __construct(
        int $type, int $kernel, float $cost = 1.0, float $nu = 0.5, int $degree = 3,
        float $gamma = null, float $coef0 = 0.0, float $epsilon = 0.1, float $tolerance = 0.001,
        int $cacheSize = 100, bool $shrinking = true, bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;

        $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..', '..'])).DIRECTORY_SEPARATOR;

        $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Overrides the directory in which the libsvm binaries are looked up.
     *
     * A trailing directory separator is appended when missing, because the
     * binary name is concatenated directly onto this path.
     *
     * @param string $binPath
     *
     * @return $this
     */
    public function setBinPath(string $binPath)
    {
        $this->binPath = $this->withTrailingSeparator($binPath);

        return $this;
    }

    /**
     * Overrides the directory in which temporary files are created.
     *
     * A trailing directory separator is appended when missing, because the
     * temporary file name is concatenated directly onto this path.
     *
     * @param string $varPath
     *
     * @return $this
     */
    public function setVarPath(string $varPath)
    {
        $this->varPath = $this->withTrailingSeparator($varPath);

        return $this;
    }

    /**
     * Adds the given samples/targets to those already seen and retrains the
     * model on the whole accumulated set by running svm-train.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \RuntimeException when a temporary file cannot be written or read,
     *                           or when svm-train exits with a non-zero status
     */
    public function train(array $samples, array $targets)
    {
        // $this->samples is not declared in this class; presumably it comes from the Trainable trait.
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $isRegression = in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true);
        $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, $isRegression);

        $trainingSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $trainingSetFileName.'-model';

        try {
            $this->writeFile($trainingSetFileName, $trainingSet);
            $this->runCommand($this->buildTrainCommand($trainingSetFileName, $modelFileName));
            $this->model = $this->readFile($modelFileName);
        } finally {
            // Temporary files must not outlive the call, even when training fails.
            $this->removeFiles([$trainingSetFileName, $modelFileName]);
        }
    }

    /**
     * @return string|null contents of the trained model file, or null before train() has succeeded
     */
    public function getModel()
    {
        return $this->model;
    }

    /**
     * Runs svm-predict on the given samples using the stored model.
     *
     * @param array $samples a list of samples, or a single sample (a flat array of values)
     *
     * @return array|mixed list of predictions, or just the first prediction when
     *                     $samples[0] is not itself an array (single-sample call)
     *
     * @throws \RuntimeException when a temporary file cannot be written or read,
     *                           or when svm-predict exits with a non-zero status
     */
    public function predict(array $samples)
    {
        $testSet = DataTransformer::testSet($samples);

        $testSetFileName = $this->varPath.uniqid('phpml', true);
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        try {
            $this->writeFile($testSetFileName, $testSet);
            $this->writeFile($modelFileName, $this->model);

            // File names are quoted individually, matching buildTrainCommand().
            $command = sprintf(
                '%ssvm-predict%s %s %s %s',
                $this->binPath,
                $this->getOSExtension(),
                escapeshellarg($testSetFileName),
                escapeshellarg($modelFileName),
                escapeshellarg($outputFileName)
            );
            $this->runCommand($command);

            $predictions = $this->readFile($outputFileName);
        } finally {
            $this->removeFiles([$testSetFileName, $modelFileName, $outputFileName]);
        }

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
            // Classification output is mapped back through the known targets.
            $predictions = DataTransformer::predictions($predictions, $this->targets);
        } else {
            $predictions = explode(PHP_EOL, trim($predictions));
        }

        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * Suffix appended to the libsvm binary names for the current platform.
     *
     * @return string '.exe' when PHP_OS starts with "WIN", '-osx' when it starts with "DAR", '' otherwise
     */
    private function getOSExtension()
    {
        $os = strtoupper(substr(PHP_OS, 0, 3));

        if ($os === 'WIN') {
            return '.exe';
        }

        if ($os === 'DAR') {
            return '-osx';
        }

        return '';
    }

    /**
     * Builds the svm-train command line from the configured options.
     *
     * @param string $trainingSetFileName
     * @param string $modelFileName
     *
     * @return string
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf('%ssvm-train%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            $this->binPath,
            $this->getOSExtension(),
            $this->type,
            $this->kernel,
            $this->cost,
            $this->nu,
            $this->degree,
            $this->gamma !== null ? ' -g '.$this->gamma : '',
            $this->coef0,
            $this->epsilon,
            $this->cacheSize,
            $this->tolerance,
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Returns $path with a trailing directory separator; an empty path is left untouched.
     *
     * @param string $path
     *
     * @return string
     */
    private function withTrailingSeparator(string $path): string
    {
        if ($path === '') {
            return $path;
        }

        $lastChar = substr($path, -1);
        if ($lastChar === '/' || $lastChar === DIRECTORY_SEPARATOR) {
            return $path;
        }

        return $path.DIRECTORY_SEPARATOR;
    }

    /**
     * Executes a shell command and fails loudly when it exits with a non-zero status.
     *
     * @param string $command
     *
     * @throws \RuntimeException
     */
    private function runCommand(string $command)
    {
        $output = [];
        $exitCode = 0;
        // stderr is merged into the captured output so it can be reported on failure.
        exec(escapeshellcmd($command).' 2>&1', $output, $exitCode);

        if ($exitCode !== 0) {
            throw new \RuntimeException(sprintf(
                'Command "%s" failed with exit code %d: %s',
                $command,
                $exitCode,
                implode(PHP_EOL, $output)
            ));
        }
    }

    /**
     * Writes $contents to $fileName.
     *
     * @param string $fileName
     * @param mixed  $contents anything file_put_contents() accepts
     *
     * @throws \RuntimeException when the file cannot be written
     */
    private function writeFile(string $fileName, $contents)
    {
        if (file_put_contents($fileName, $contents) === false) {
            throw new \RuntimeException(sprintf('Unable to write file "%s"', $fileName));
        }
    }

    /**
     * Reads the whole of $fileName.
     *
     * @param string $fileName
     *
     * @return string
     *
     * @throws \RuntimeException when the file is missing or cannot be read
     */
    private function readFile(string $fileName): string
    {
        $contents = is_file($fileName) ? file_get_contents($fileName) : false;

        if ($contents === false) {
            throw new \RuntimeException(sprintf('Unable to read file "%s"', $fileName));
        }

        return $contents;
    }

    /**
     * Deletes those of the given files that exist.
     *
     * @param string[] $fileNames
     */
    private function removeFiles(array $fileNames)
    {
        foreach ($fileNames as $fileName) {
            if (is_file($fileName)) {
                unlink($fileName);
            }
        }
    }
}
260