1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification\Ensemble; |
6
|
|
|
|
7
|
|
|
use Phpml\Classification\Linear\DecisionStump; |
8
|
|
|
use Phpml\Classification\WeightedClassifier; |
9
|
|
|
use Phpml\Math\Statistic\Mean; |
10
|
|
|
use Phpml\Math\Statistic\StandardDeviation; |
11
|
|
|
use Phpml\Classification\Classifier; |
12
|
|
|
use Phpml\Helper\Predictable; |
13
|
|
|
use Phpml\Helper\Trainable; |
14
|
|
|
|
15
|
|
|
/**
 * ADAptive BOOSTing (AdaBoost) ensemble for binary classification.
 *
 * Repeatedly trains a 'weak' base classifier (DecisionStump by default) on a
 * re-weighted view of the training data and combines the resulting
 * classifiers through an alpha-weighted vote. Exactly two distinct target
 * labels are supported.
 */
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * Actual labels given in the targets array.
     * After train() this maps the internal class codes 1 and -1 to the
     * original labels.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * @var int
     */
    protected $sampleCount;

    /**
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights, normalised to sum to 1 after every update
     *
     * @var array
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights (one alpha per entry of $classifiers)
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every new base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations number of boosting rounds (weak classifiers) to train
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    fully-qualified class name of the weak classifier
     * @param array  $classifierOptions constructor arguments for each new instance
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = [])
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble.
     *
     * The given samples/targets are appended to any previously supplied data,
     * then the boosting procedure is run from scratch over the whole set.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \InvalidArgumentException when targets do not contain exactly two
     *                                   distinct labels, or when the number of
     *                                   samples differs from the number of targets
     */
    public function train(array $samples, array $targets)
    {
        // Initialize usual variables
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) != 2) {
            throw new \InvalidArgumentException("AdaBoost is a binary classifier and can classify between two classes only");
        }
        if (count($samples) !== count($targets)) {
            throw new \InvalidArgumentException('Number of samples does not match number of targets');
        }

        // Set all target values to either -1 or 1: the first label becomes 1,
        // the second becomes -1.
        $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
        foreach ($targets as $target) {
            // Loose comparison on purpose: array_count_values() turns numeric
            // string labels into integer keys, so '1' must still match 1.
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        // Re-index so that sample positions line up with the targets, which
        // are appended positionally above.
        $this->samples = array_merge($this->samples, array_values($samples));
        $this->featureCount = count(reset($samples));
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters: uniform weights, empty ensemble
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Determine the best 'weak' classifier based on current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Returns a new base classifier trained with the current sample weights.
     *
     * Classifiers extending WeightedClassifier receive the weights directly;
     * any other classifier is trained on a weight-driven resample of the data.
     *
     * @return Classifier
     */
    protected function getBestClassifier()
    {
        $ref = new \ReflectionClass($this->baseClassifier);
        if ($this->classifierOptions) {
            $classifier = $ref->newInstanceArgs($this->classifierOptions);
        } else {
            $classifier = $ref->newInstance();
        }

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            list($samples, $targets) = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset.
     *
     * Each sample gets (z-score of its weight - smallest z-score + 1) draws,
     * and each draw keeps the sample with probability 1/2, so the result is
     * random and may be smaller than the original dataset.
     *
     * @return array [samples, targets]
     */
    protected function resample()
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);

        // With identical weights (always the case on the first iteration) the
        // standard deviation is 0; every sample then has z-score 0 instead of
        // triggering a division by zero.
        $zScore = function (float $weight) use ($mean, $std): int {
            return $std == 0.0 ? 0 : (int) round(($weight - $mean) / $std);
        };
        $minZ = $zScore(min($weights));

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            // Heavier samples get more draws; the lightest sample gets one.
            $draws = $zScore($weight) - $minZ + 1;
            for ($i = 0; $i < $draws; $i++) {
                if (rand(0, 1) == 0) {
                    continue;
                }
                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the weighted classification error
     * rate (sum of weights of misclassified samples / sum of all weights).
     *
     * @param Classifier $classifier
     *
     * @return float
     */
    protected function evaluateClassifier(Classifier $classifier)
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        foreach ($this->samples as $index => $sample) {
            if ($classifier->predict($sample) != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha of a classifier: 0.5 * ln((1 - e) / e).
     *
     * @param float $errorRate weighted error rate in [0, 1]
     *
     * @return float
     */
    protected function calculateAlpha(float $errorRate)
    {
        // An error rate of exactly 0 or 1 would make the logarithm infinite
        // and turn every sample weight into INF/0/NaN, so nudge it inside (0, 1).
        $epsilon = 1e-10;
        if ($errorRate <= 0) {
            $errorRate = $epsilon;
        } elseif ($errorRate >= 1) {
            $errorRate = 1 - $epsilon;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights after a boosting round and normalises them
     * so they sum to 1.
     *
     * @param Classifier $classifier
     * @param float      $alpha
     */
    protected function updateWeights(Classifier $classifier, float $alpha)
    {
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            // Assuming the weak classifier predicts 1 or -1: for a positive
            // alpha, correct predictions shrink the weight, wrong ones grow it.
            $weightsT1[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalise by the sum of the NEW weights so they remain a
        // distribution and cannot drift towards overflow/underflow.
        $sumOfWeights = array_sum($weightsT1);
        if ($sumOfWeights > 0) {
            foreach ($weightsT1 as $index => $weight) {
                $weightsT1[$index] = $weight / $sumOfWeights;
            }
        }

        $this->weights = $weightsT1;
    }

    /**
     * Predicts the label of a single sample by the sign of the alpha-weighted
     * vote of all trained weak classifiers (ties resolve to the second label).
     *
     * @param array $sample
     *
     * @return mixed
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }
}
266
|
|
|
|
Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.
You can also find more detailed suggestions in the “Code” section of your repository.