1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification\Linear; |
6
|
|
|
|
7
|
|
|
use Phpml\Helper\Predictable; |
8
|
|
|
use Phpml\Helper\Trainable; |
9
|
|
|
use Phpml\Classification\Classifier; |
10
|
|
|
use Phpml\Preprocessing\Normalizer; |
11
|
|
|
|
12
|
|
|
class Perceptron implements Classifier
{
    use Predictable;

    /**
     * Name of the method whose result is used to calculate the network error
     * for each instance. It is resolved through late static binding
     * (static::$errorFunction), so subclasses can substitute another method.
     *
     * @var string
     */
    protected static $errorFunction = 'outputClass';

    /**
     * Training samples accumulated over calls to train().
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Training targets mapped to the internal class values 1 / -1,
     * index-aligned with $samples.
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Maps the internal class values (1 and -1) back to the original labels.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Bias weight at index 0, followed by one weight per feature.
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Null when input normalization is disabled.
     *
     * @var Normalizer|null
     */
    protected $normalizer;

    /**
     * Minimum amount of change in the weights between iterations
     * that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-5;

    /**
     * Initialize a perceptron classifier with the given learning rate and maximum
     * number of iterations used while training the perceptron <br>
     *
     * Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive) <br>
     * Maximum number of iterations can be an integer value greater than 0
     *
     * @param float $learningRate
     * @param int   $maxIterations
     * @param bool  $normalizeInputs when true, samples are transformed with a
     *                               Normalizer::NORM_STD normalizer before training
     *                               and before prediction
     *
     * @throws \Exception when $learningRate or $maxIterations is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000,
        bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets minimum value for the change in the weights
     * between iterations to continue the iterations.<br>
     *
     * If no single weight changes by more than the given value during an
     * iteration, the algorithm will stop training
     *
     * @param float $threshold
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;
    }

    /**
     * Trains the perceptron on the given samples and targets.
     *
     * Samples and mapped targets are appended to any data from previous calls;
     * the label mapping and the weights are recomputed on every call.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when $targets does not contain exactly two distinct labels
     */
    public function train(array $samples, array $targets)
    {
        $labels = array_keys(array_count_values($targets));
        // Exactly two classes are required: with a single class the second
        // label would be undefined and predictions could map to null.
        if (count($labels) !== 2) {
            throw new \Exception("Perceptron is for binary (two-class) classification only");
        }

        if ($this->normalizer !== null) {
            $this->normalizer->transform($samples);
        }

        // Internal class 1 stands for the first label, -1 for the second one
        $this->labels = [1 => $labels[0], -1 => $labels[1]];
        foreach ($targets as $target) {
            // Loose comparison on purpose: array_count_values() turns numeric
            // string labels into integer keys, so '1' must still match 1.
            $this->targets[] = $target == $labels[0] ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($this->samples[0]);

        // Init bias + feature weights with random values in [0, 1].
        // Built without a by-reference foreach: a lingering reference would
        // keep the last element aliased in every copy made by runTraining().
        $this->weights = [];
        for ($i = 0; $i <= $this->featureCount; $i++) {
            $this->weights[] = rand() / (float) getrandmax();
        }

        // Do training
        $this->runTraining();
    }

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of perceptron learning rule, keeping the best weights
     * seen so far ("pocket" algorithm)
     */
    protected function runTraining()
    {
        $bestWeights = null;
        $bestScore = count($this->samples);

        for ($iteration = 0; $iteration < $this->maxIterations; $iteration++) {
            // Weights at the start of this pass; the pass's mistake count is
            // attributed to them when deciding what goes into the pocket.
            $weights = $this->weights;
            $misClassified = 0;

            foreach ($this->samples as $index => $sample) {
                $target = $this->targets[$index];
                $prediction = $this->{static::$errorFunction}($sample);
                $update = $target - $prediction;
                if ($target != $prediction) {
                    $misClassified++;
                }

                // Update bias
                $this->weights[0] += $update * $this->learningRate;
                // Update other weights
                for ($i = 1; $i <= $this->featureCount; $i++) {
                    $this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
                }
            }

            // Save the best weights in the "pocket" so that
            // any future weights worse than this will be disregarded
            if ($bestWeights === null || $misClassified <= $bestScore) {
                $bestWeights = $weights;
                $bestScore = $misClassified;
            }

            // Check for early stop
            if ($this->earlyStop($weights)) {
                break;
            }
        }

        // The weights in the pocket are better than or equal to the last state
        // so, we use these weights
        $this->weights = $bestWeights;
    }

    /**
     * Returns true when no weight moved by more than the configured change
     * threshold (see setChangeThreshold()) relative to $oldWeights.
     *
     * @param array $oldWeights
     *
     * @return boolean
     */
    protected function earlyStop($oldWeights)
    {
        foreach ($oldWeights as $index => $oldWeight) {
            if (abs($oldWeight - $this->weights[$index]) > $this->threshold) {
                return false;
            }
        }

        return true;
    }

    /**
     * Calculates net output of the network as a float value for the given input
     *
     * @param array $sample
     * @return float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            // Index 0 is the bias; weight $index belongs to feature $index - 1
            $sum += $index == 0 ? $w : $w * $sample[$index - 1];
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     *
     * @param array $sample
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Predicts the original label for a single sample, applying the same
     * normalization as during training when it is enabled.
     *
     * @param array $sample
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if ($this->normalizer !== null) {
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
265
|
|
|
|
Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.
You can also find more detailed suggestions in the “Code” section of your repository.