1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification\Linear; |
6
|
|
|
|
7
|
|
|
use Phpml\Helper\Predictable; |
8
|
|
|
use Phpml\Helper\OneVsRest; |
9
|
|
|
use Phpml\Classification\Classifier; |
10
|
|
|
use Phpml\Preprocessing\Normalizer; |
11
|
|
|
|
12
|
|
|
class Perceptron implements Classifier
{
    use Predictable, OneVsRest;

    /**
     * Name of the method whose result is compared against the target to
     * calculate the network error for each instance. Subclasses may point
     * this at another method (e.g. the raw net output) to change the
     * learning rule without overriding runTraining().
     *
     * @var string
     */
    protected static $errorFunction = 'outputClass';

    /**
     * Training samples, one feature vector per row
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Training targets, re-encoded to -1/+1 during trainBinary()
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Original class labels, keyed by the internal +1/-1 encoding
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features in each training sample
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Weight vector; index 0 holds the bias term, indices 1..featureCount
     * hold the per-feature weights
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * @var Normalizer
     */
    protected $normalizer;

    /**
     * Minimum amount of change in the weights between iterations
     * that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-5;

    /**
     * Initialize a perceptron classifier with given learning rate and maximum
     * number of iterations used while training the perceptron <br>
     *
     * Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive) <br>
     * Maximum number of iterations can be an integer value greater than 0
     *
     * @param float $learningRate
     * @param int $maxIterations
     * @param bool $normalizeInputs Whether to standard-normalize samples before training/prediction
     *
     * @throws \Exception When learning rate or iteration count is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000,
        bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets minimum value for the change in the weights
     * between iterations to continue the iterations.<br>
     *
     * If the weight change is less than given value then the
     * algorithm will stop training
     *
     * @param float $threshold
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;
    }

    /**
     * Trains the perceptron on a two-class dataset.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception When more than two distinct target labels are given
     */
    public function trainBinary(array $samples, array $targets)
    {
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) > 2) {
            throw new \Exception("Perceptron is for binary (two-class) classification only");
        }

        if ($this->normalizer) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
        foreach ($targets as $target) {
            $this->targets[] = strval($target) == strval($this->labels[1]) ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($this->samples[0]);

        // Init weights (bias + one weight per feature) with random values in [0, 1)
        $this->weights = array_fill(0, $this->featureCount + 1, 0);
        foreach ($this->weights as &$weight) {
            $weight = rand() / (float) getrandmax();
        }
        unset($weight); // break the dangling reference left by foreach-by-ref

        // Do training
        $this->runTraining();
    }

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of the perceptron learning rule, while keeping the best
     * weight vector seen so far in a "pocket" (pocket algorithm).
     */
    protected function runTraining()
    {
        $currIter = 0;
        $bestWeights = null;
        // Worst possible score: everything misclassified
        $bestScore = count($this->samples);

        while ($this->maxIterations > $currIter++) {
            // Snapshot of the weights at the start of this epoch; used both
            // for the pocket and for the early-stop comparison below
            $weights = $this->weights;
            $misClassified = 0;
            foreach ($this->samples as $index => $sample) {
                $target = $this->targets[$index];
                $prediction = $this->{static::$errorFunction}($sample);
                $update = $target - $prediction;
                if ($target != $prediction) {
                    $misClassified++;
                }
                // Update bias
                $this->weights[0] += $update * $this->learningRate; // Bias
                // Update other weights
                for ($i = 1; $i <= $this->featureCount; $i++) {
                    $this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
                }
            }

            // Save the best weights in the "pocket" so that
            // any future weights worse than this will be disregarded
            if ($bestWeights === null || $misClassified <= $bestScore) {
                $bestWeights = $weights;
                $bestScore = $misClassified;
            }

            // Check for early stop
            if ($this->earlyStop($weights)) {
                break;
            }
        }

        // The weights in the pocket are better than or equal to the last state
        // so, we use these weights
        $this->weights = $bestWeights;
    }

    /**
     * Returns true when training should stop early: no single weight
     * changed by more than $this->threshold during the last iteration.
     *
     * Fix: the 1e-5 limit used to be hard-coded here, which silently
     * ignored the value configured via setChangeThreshold().
     *
     * @param array $oldWeights
     *
     * @return boolean
     */
    protected function earlyStop($oldWeights)
    {
        // Mark each weight with 1 if it moved by more than the threshold
        $diff = array_map(
            function ($w1, $w2) {
                return abs($w1 - $w2) > $this->threshold ? 1 : 0;
            },
            $oldWeights, $this->weights);

        if (array_sum($diff) == 0) {
            return true;
        }

        return false;
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample
     *
     * @param array $sample
     *
     * @return array
     */
    protected function checkNormalizedSample(array $sample)
    {
        if ($this->normalizer) {
            // Normalizer works on a batch, so wrap/unwrap the single sample
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates net output of the network as a float value for the given input
     * (bias plus the dot product of weights and features)
     *
     * @param array $sample
     *
     * @return float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index == 0) {
                // Index 0 is the bias term and has no matching feature
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     *
     * @param array $sample
     *
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns the probability of the sample of belonging to the given label.
     *
     * The probability is simply taken as the distance of the sample
     * to the decision plane.
     *
     * @param array $sample
     * @param mixed $label
     *
     * @return float
     */
    protected function predictProbability(array $sample, $label)
    {
        $predicted = $this->predictSampleBinary($sample);

        if (strval($predicted) == strval($label)) {
            $sample = $this->checkNormalizedSample($sample);
            return abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * Predicts the original class label (not the internal -1/+1 code)
     * for a single sample.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[ $predictedClass ];
    }
}
301
|
|
|
|
Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.
You can also find more detailed suggestions in the “Code” section of your repository.