Completed
Push — develop ( 95bfc8...365a9b )
by Arkadiusz
02:42
created

SupportVectorMachine::buildTrainCommand()   A

Complexity

Conditions 2
Paths 1

Size

Total Lines 21
Code Lines 18

Duplication

Lines 0
Ratio 0 %
Metric Value
dl 0
loc 21
rs 9.3142
cc 2
eloc 18
nc 1
nop 2
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
class SupportVectorMachine
{
    /**
     * SVM type; passed to svm-train as `-s`.
     *
     * @var int
     */
    private $type;

    /**
     * Kernel type; passed to svm-train as `-t`.
     *
     * @var int
     */
    private $kernel;

    /**
     * Cost parameter; passed to svm-train as `-c`.
     *
     * @var float
     */
    private $cost;

    /**
     * Nu parameter; passed to svm-train as `-n`.
     *
     * @var float
     */
    private $nu;

    /**
     * Kernel degree; passed to svm-train as `-d`.
     *
     * @var int
     */
    private $degree;

    /**
     * Kernel gamma; passed to svm-train as `-g` only when not null.
     *
     * @var float|null
     */
    private $gamma;

    /**
     * Kernel coef0; passed to svm-train as `-r`.
     *
     * @var float
     */
    private $coef0;

    /**
     * Epsilon; passed to svm-train as `-p`.
     *
     * @var float
     */
    private $epsilon;

    /**
     * Stopping tolerance; passed to svm-train as `-e`.
     *
     * @var float
     */
    private $tolerance;

    /**
     * Cache size; passed to svm-train as `-m`.
     *
     * @var int
     */
    private $cacheSize;

    /**
     * Shrinking heuristics flag; passed to svm-train as `-h` (0/1).
     *
     * @var bool
     */
    private $shrinking;

    /**
     * Probability estimates flag; passed to svm-train as `-b` (0/1).
     *
     * @var bool
     */
    private $probabilityEstimates;

    /**
     * Directory holding the svm-train / svm-predict binaries (bin/libsvm/ under
     * the project root), with a trailing directory separator.
     *
     * @var string
     */
    private $binPath;

    /**
     * Directory used for temporary data/model/output files (var/ under the
     * project root), with a trailing directory separator.
     *
     * @var string
     */
    private $varPath;

    /**
     * Contents of the model file produced by the last successful train() call;
     * null until train() has run.
     *
     * @var string|null
     */
    private $model;

    /**
     * Labels given to the last train() call; used to map classification
     * predictions back to the original labels.
     *
     * @var array
     */
    private $labels;

    /**
     * @param int        $type
     * @param int        $kernel
     * @param float      $cost
     * @param float      $nu
     * @param int        $degree
     * @param float|null $gamma                null omits `-g`, so svm-train falls back to its own default
     * @param float      $coef0
     * @param float      $epsilon
     * @param float      $tolerance
     * @param int        $cacheSize
     * @param bool       $shrinking
     * @param bool       $probabilityEstimates
     */
    public function __construct(
        int $type, int $kernel, float $cost = 1.0, float $nu = 0.5, int $degree = 3,
        float $gamma = null, float $coef0 = 0.0, float $epsilon = 0.1, float $tolerance = 0.001,
        int $cacheSize = 100, bool $shrinking = true, bool $probabilityEstimates = false
    ) {
        $this->type = $type;
        $this->kernel = $kernel;
        $this->cost = $cost;
        $this->nu = $nu;
        $this->degree = $degree;
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->epsilon = $epsilon;
        $this->tolerance = $tolerance;
        $this->cacheSize = $cacheSize;
        $this->shrinking = $shrinking;
        $this->probabilityEstimates = $probabilityEstimates;

        // Project root is three directories above this file's directory.
        $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [dirname(__FILE__), '..', '..', '..'])).DIRECTORY_SEPARATOR;

        $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
        $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
    }

    /**
     * Trains a model.
     *
     * The samples are written to a temporary file in varPath, svm-train is run
     * on that file, and the resulting model file's contents are kept in memory.
     * Temporary files are removed whether or not training succeeds.
     *
     * @param array $samples
     * @param array $labels
     *
     * @throws \RuntimeException when a temporary file cannot be written/read or svm-train fails
     */
    public function train(array $samples, array $labels)
    {
        $this->labels = $labels;
        $isRegression = in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR]);
        $trainingSet = DataTransformer::trainingSet($samples, $labels, $isRegression);

        $trainingSetFileName = $this->varPath.uniqid();
        $modelFileName = $trainingSetFileName.'-model';

        try {
            $this->writeFile($trainingSetFileName, $trainingSet);
            $this->runCommand($this->buildTrainCommand($trainingSetFileName, $modelFileName));
            $this->model = $this->readFile($modelFileName);
        } finally {
            $this->removeFiles([$trainingSetFileName, $modelFileName]);
        }
    }

    /**
     * @return string|null model file contents from the last train() call, or null if untrained
     */
    public function getModel()
    {
        return $this->model;
    }

    /**
     * Predicts targets for the given samples.
     *
     * The samples and the in-memory model are written to temporary files in
     * varPath, svm-predict is run on them, and its output file is parsed.
     * Temporary files are removed whether or not prediction succeeds.
     *
     * @param array $samples a list of samples, or a single sample (first element not an array)
     *
     * @return array|mixed one prediction per sample, or a single prediction when $samples is one sample
     *
     * @throws \RuntimeException when no model is available, a temporary file cannot be written/read,
     *                           or svm-predict fails
     */
    public function predict(array $samples)
    {
        if ($this->model === null) {
            throw new \RuntimeException('No model available: call train() before predict()');
        }

        $testSetFileName = $this->varPath.uniqid();
        $modelFileName = $testSetFileName.'-model';
        $outputFileName = $testSetFileName.'-output';

        try {
            $this->writeFile($testSetFileName, DataTransformer::testSet($samples));
            $this->writeFile($modelFileName, $this->model);
            $this->runCommand(sprintf(
                '%s %s %s %s',
                escapeshellarg($this->binPath.'svm-predict'.$this->getOSExtension()),
                escapeshellarg($testSetFileName),
                escapeshellarg($modelFileName),
                escapeshellarg($outputFileName)
            ));
            $predictions = $this->readFile($outputFileName);
        } finally {
            $this->removeFiles([$testSetFileName, $modelFileName, $outputFileName]);
        }

        if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
            $predictions = DataTransformer::predictions($predictions, $this->labels);
        } else {
            // Split on any line ending: the output file's line endings need not match PHP_EOL.
            $predictions = preg_split('/\r\n|\n|\r/', trim($predictions));
        }

        if (!is_array($samples[0])) {
            return $predictions[0];
        }

        return $predictions;
    }

    /**
     * @return string '.exe' on Windows, '' elsewhere
     */
    private function getOSExtension()
    {
        if (strtoupper(substr(PHP_OS, 0, 3)) === 'WIN') {
            return '.exe';
        }

        return '';
    }

    /**
     * Builds the svm-train command line.
     *
     * Every path (including the binary itself) is escaped individually with
     * escapeshellarg, so paths containing spaces or quotes are passed intact on
     * both POSIX shells and cmd.exe. Numeric options come from typed int/float/bool
     * properties and need no escaping.
     *
     * @param string $trainingSetFileName
     * @param string $modelFileName
     *
     * @return string
     */
    private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
    {
        return sprintf(
            '%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s',
            escapeshellarg($this->binPath.'svm-train'.$this->getOSExtension()),
            $this->type,
            $this->kernel,
            $this->cost,
            $this->nu,
            $this->degree,
            $this->gamma !== null ? ' -g '.$this->gamma : '',
            $this->coef0,
            $this->epsilon,
            $this->cacheSize,
            $this->tolerance,
            $this->shrinking,
            $this->probabilityEstimates,
            escapeshellarg($trainingSetFileName),
            escapeshellarg($modelFileName)
        );
    }

    /**
     * Runs an already-escaped command line, capturing stdout and stderr.
     *
     * @param string $command
     *
     * @throws \RuntimeException when the command exits with a non-zero status
     */
    private function runCommand(string $command)
    {
        $output = [];
        $exitCode = 0;
        exec($command.' 2>&1', $output, $exitCode);

        if ($exitCode !== 0) {
            throw new \RuntimeException(sprintf(
                'Command "%s" failed with exit code %d: %s',
                $command,
                $exitCode,
                implode(PHP_EOL, $output)
            ));
        }
    }

    /**
     * Writes contents to a file, failing loudly instead of letting an svm-*
     * binary run against a missing file.
     *
     * @param string $path
     * @param mixed  $contents anything file_put_contents accepts
     *
     * @throws \RuntimeException when the file cannot be written
     */
    private function writeFile(string $path, $contents)
    {
        if (file_put_contents($path, $contents) === false) {
            throw new \RuntimeException(sprintf('Unable to write file "%s"', $path));
        }
    }

    /**
     * Reads a file produced by an svm-* binary.
     *
     * @param string $path
     *
     * @return string
     *
     * @throws \RuntimeException when the file does not exist or cannot be read
     */
    private function readFile(string $path): string
    {
        $contents = is_file($path) ? file_get_contents($path) : false;
        if ($contents === false) {
            throw new \RuntimeException(sprintf('Unable to read file "%s"', $path));
        }

        return $contents;
    }

    /**
     * Deletes the given temporary files, skipping any that were never created.
     *
     * @param string[] $paths
     */
    private function removeFiles(array $paths)
    {
        foreach ($paths as $path) {
            if (is_file($path)) {
                unlink($path);
            }
        }
    }
}
231