<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Phpml\Classification\Classifier;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Preprocessing\Normalizer;

/**
 * Single-layer perceptron for binary (two-class) classification.
 *
 * The model is a flat weight list: index 0 holds the bias term and index i
 * (1..featureCount) holds the weight applied to feature i - 1 of a sample.
 */
class Perceptron implements Classifier
{
    use Predictable;

    /**
     * Name of the method whose result is compared against the encoded target
     * (1 or -1) to calculate the error of each training instance.
     * It is resolved through static::, so a subclass may redeclare it.
     *
     * @var string
     */
    protected static $errorFunction = 'outputClass';

    /**
     * Training samples accumulated over every train() call.
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Encoded targets (1 or -1) accumulated over every train() call.
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Maps the encoded class (keys 1 and -1) back to the original label.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per sample, taken from the first stored sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Bias (index 0) followed by one weight per feature.
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Set only when input normalization was requested in the constructor.
     *
     * @var Normalizer
     */
    protected $normalizer;

    /**
     * Initialize a perceptron classifier with the given learning rate and the
     * maximum number of iterations used while training the perceptron.
     *
     * Learning rate should be a float value between 0.0 (exclusive) and 1.0 (inclusive).
     * Maximum number of iterations should be an integer value greater than 0.
     *
     * @param float $learningRate
     * @param int   $maxIterations
     * @param bool  $normalizeInputs When true, samples are passed through a
     *                               Normalizer::NORM_STD normalizer before training and prediction
     *
     * @throws \Exception when the learning rate or the iteration count is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000,
        bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains the perceptron on the given samples.
     *
     * The first distinct target value is encoded as 1, the second as -1.
     *
     * NOTE(review): samples and encoded targets accumulate across calls, while
     * the label mapping is rebuilt from the targets of the current call only.
     * Repeated calls whose targets appear in a different order will mix
     * encodings - confirm whether incremental training is intended here.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when the targets do not contain exactly two classes
     *                    or when the sample and target counts differ
     */
    public function train(array $samples, array $targets)
    {
        // Distinct target values, in order of first appearance
        $labels = array_keys(array_count_values($targets));
        if (count($labels) !== 2) {
            // Fewer than two classes would leave the 1 / -1 mapping incomplete
            throw new \Exception("Perceptron is for only binary (two-class) classification");
        }

        if (count($samples) !== count($targets)) {
            throw new \Exception("Number of samples and number of targets should be equal");
        }

        if ($this->normalizer !== null) {
            // transform() normalizes $samples in place
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [1 => $labels[0], -1 => $labels[1]];

        // Compare as strings: array keys turn numeric strings into ints, and a
        // loose == between a non-numeric string and 0 is true before PHP 8
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $target) {
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($this->samples[0]);

        // Init bias and feature weights with random values in [0, 1]
        $this->weights = [];
        for ($i = 0; $i <= $this->featureCount; $i++) {
            $this->weights[] = rand() / (float) getrandmax();
        }

        // Do training
        $this->runTraining();
    }

    /**
     * Adapts the weights with respect to the stored samples and targets
     * by use of the perceptron learning rule.
     *
     * Runs exactly $maxIterations passes over all stored samples; there is no
     * early stopping.
     */
    protected function runTraining()
    {
        for ($iteration = 0; $iteration < $this->maxIterations; $iteration++) {
            foreach ($this->samples as $index => $sample) {
                $target = $this->targets[$index];
                // The error function is looked up by name so subclasses can swap it
                $prediction = $this->{static::$errorFunction}($sample);
                $update = $target - $prediction;

                // Update bias
                $this->weights[0] += $update * $this->learningRate;

                // Update other weights
                for ($i = 1; $i <= $this->featureCount; $i++) {
                    $this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
                }
            }
        }
    }

    /**
     * Calculates the net output of the network for the given input:
     * the bias plus the weighted sum of the sample's features.
     *
     * @param array $sample
     * @return float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index === 0) {
                // Index 0 is the bias and has no matching feature
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input.
     * A net output of exactly 0 is classified as -1.
     *
     * @param array $sample
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Predicts the original label for a single sample, normalizing it first
     * when a normalizer is configured.
     *
     * @param array $sample
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if ($this->normalizer !== null) {
            // transform() works in place on a list of samples, so wrap and unwrap
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}

// Review note: Duplicated code is one of the most pungent code smells. If you need to duplicate the same code
// in three or more different places, we strongly encourage you to look into extracting the code into a single
// class or operation.
// You can also find more detailed suggestions in the "Code" section of your repository.