1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification\Ensemble; |
6
|
|
|
|
7
|
|
|
use Phpml\Classification\Classifier; |
8
|
|
|
use Phpml\Classification\Linear\DecisionStump; |
9
|
|
|
use Phpml\Classification\WeightedClassifier; |
10
|
|
|
use Phpml\Exception\InvalidArgumentException; |
11
|
|
|
use Phpml\Helper\Predictable; |
12
|
|
|
use Phpml\Helper\Trainable; |
13
|
|
|
use Phpml\Math\Statistic\Mean; |
14
|
|
|
use Phpml\Math\Statistic\StandardDeviation; |
15
|
|
|
use ReflectionClass; |
16
|
|
|
|
17
|
|
|
class AdaBoost implements Classifier
{
    use Predictable;
    use Trainable;

    /**
     * Original labels found in the targets, keyed by their internal encoding:
     * 1 for the first label seen in the last train() call, -1 for the second.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples accumulated over all train() calls
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features, taken from the first sample of the last train() call
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of boosting rounds performed by train()
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Current weight of every training sample, index-aligned with the samples
     *
     * @var array
     */
    protected $weights = [];

    /**
     * Weak classifiers trained in each boosting round, in round order
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Vote weight (alpha) of each weak classifier, index-aligned with $classifiers
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * Class name of the weak learner instantiated in every boosting round
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every new base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations number of boosting rounds (weak classifiers) to train
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    class name, instantiated via reflection in every round
     * @param array  $classifierOptions constructor arguments for that class (none when empty)
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble: runs $maxIterations boosting rounds, each adding one
     * weak classifier and its vote weight, re-weighting the samples in between.
     *
     * NOTE(review): samples and encoded targets are appended across calls, but
     * the label map is rebuilt from the current call only, so targets from an
     * earlier call stay encoded against the earlier label order. Confirm that
     * this incremental behaviour is intended.
     *
     * @throws InvalidArgumentException when $targets does not contain exactly two distinct labels
     */
    public function train(array $samples, array $targets): void
    {
        // Validate before touching any state, so a rejected call leaves a
        // previously trained model intact.
        $labels = array_keys(array_count_values($targets));
        if (count($labels) !== 2) {
            throw new InvalidArgumentException('AdaBoost is a binary classifier and can classify between two classes only');
        }

        // Set all target values to either 1 (first label) or -1 (second label)
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];
        foreach ($targets as $target) {
            // Loose comparison on purpose: array_count_values() turns numeric-string
            // keys into ints, so a target '1' must still match the stored label 1.
            $this->targets[] = $target == $labels[0] ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Start from a uniform distribution over all samples
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Train a 'weak' classifier against the current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Predicts a single sample via an alpha-weighted vote of all weak classifiers.
     * A positive total maps to the label encoded as 1; zero or negative maps to
     * the label encoded as -1.
     *
     * @return mixed one of the original labels
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }

    /**
     * Instantiates a fresh base classifier and trains it on the current weighting.
     * A WeightedClassifier receives the sample weights directly. Any other
     * classifier is trained on a weight-driven resample of the data. Despite the
     * name, only one instance is trained; choosing the best split or model is
     * left to the base classifier's own training.
     */
    protected function getBestClassifier(): Classifier
    {
        $ref = new ReflectionClass($this->baseClassifier);
        /** @var Classifier $classifier */
        $classifier = count($this->classifierOptions) === 0 ? $ref->newInstance() : $ref->newInstanceArgs($this->classifierOptions);

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            [$samples, $targets] = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset as [samples, targets].
     *
     * Each sample gets a number of candidate copies that grows with its weight's
     * z-score (at least 1). Each copy is kept with probability 1/2.
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $min = min($weights);

        // Equal weights (always the case in the first round) have a standard
        // deviation of 0, and the z-scores below would divide by zero. Every
        // sample matters equally then, so the dataset is used as-is.
        if ($min === max($weights)) {
            return [$this->samples, $this->targets];
        }

        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);
        $minZ = (int) round(($min - $mean) / $std);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $z = (int) round(($weight - $mean) / $std) - $minZ + 1;
            for ($i = 0; $i < $z; ++$i) {
                if (random_int(0, 1) == 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the classification error rate:
     * the weight of the misclassified samples divided by the total weight.
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        foreach ($this->samples as $index => $sample) {
            $predicted = $classifier->predict($sample);
            if ($predicted != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha (vote weight) of a classifier: 0.5 * ln((1 - e) / e)
     */
    protected function calculateAlpha(float $errorRate): float
    {
        // Keep the error rate strictly inside (0, 1). Otherwise a perfect (0) or
        // fully wrong (1) classifier yields an infinite alpha, which turns the
        // sample weights into INF/NaN.
        $epsilon = 1e-10;
        $errorRate = max($epsilon, min(1 - $epsilon, $errorRate));

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: correctly classified samples (target * output = 1)
     * shrink, misclassified ones grow. The result is renormalised to sum to 1.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $updated = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            $updated[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalise by the sum of the *new* weights so they stay a distribution,
        // matching the uniform 1/n initialisation done in train().
        $total = array_sum($updated);
        if ($total > 0) {
            $updated = array_map(function ($weight) use ($total) {
                return $weight / $total;
            }, $updated);
        }

        $this->weights = $updated;
    }
}
253
|
|
|
|