These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more.
1 | <?php |
||
2 | |||
3 | declare(strict_types=1); |
||
4 | |||
5 | namespace Phpml\Classification\Ensemble; |
||
6 | |||
7 | use Exception; |
||
8 | use Phpml\Classification\Classifier; |
||
9 | use Phpml\Classification\Linear\DecisionStump; |
||
10 | use Phpml\Classification\WeightedClassifier; |
||
11 | use Phpml\Helper\Predictable; |
||
12 | use Phpml\Helper\Trainable; |
||
13 | use Phpml\Math\Statistic\Mean; |
||
14 | use Phpml\Math\Statistic\StandardDeviation; |
||
15 | use ReflectionClass; |
||
16 | |||
/**
 * ADAptive BOOSTing (AdaBoost) ensemble classifier.
 *
 * Trains up to $maxIterations 'weak' base classifiers (DecisionStump by
 * default) on re-weighted versions of the training data and predicts with
 * the alpha-weighted vote of all of them. Only two-class problems are
 * supported; the two labels are encoded internally as 1 and -1.
 *
 * NOTE(review): $this->samples, $this->targets and predict() are not declared
 * in this class; they are presumably provided by the Trainable and
 * Predictable traits - confirm against those traits.
 */
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * Actual labels given in the targets array, keyed by their internal
     * encoding once train() has run: 1 => first label encountered,
     * -1 => second label encountered.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples currently held for training
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features of the first sample passed to train()
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights
     *
     * @var array
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights (same indexes as $classifiers)
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * Class name of the base ('weak') classifier
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every new base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations Number of boosting rounds run by train()
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    Class name of the base classifier
     * @param array  $classifierOptions Constructor arguments for each new instance
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble on the given samples.
     *
     * The samples and encoded targets are appended to the ones already held
     * by the instance; weights, classifiers and alphas are rebuilt from scratch.
     *
     * @param array $samples List of feature vectors
     * @param array $targets One label per sample; exactly two distinct labels
     *
     * @throws \Exception If the targets do not contain exactly two distinct
     *                    labels, or if samples and targets differ in size
     */
    public function train(array $samples, array $targets): void
    {
        // Initialize usual variables
        $labels = array_keys(array_count_values($targets));
        if (count($labels) !== 2) {
            throw new Exception('AdaBoost is a binary classifier and can classify between two classes only');
        }

        if (count($samples) !== count($targets)) {
            throw new Exception('Number of samples must match the number of targets');
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];

        // Array keys turn numeric strings into integers, so a label may have
        // changed type on its way through array_count_values(). Comparing the
        // string forms keeps '1' and 1 equal without the loose-comparison
        // pitfalls of == (e.g. 'abc' == 0 is true before PHP 8).
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $target) {
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        // reset() reads the first sample without assuming it is stored under key 0
        $this->featureCount = count(reset($samples));
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters: every sample starts with the same weight
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Determine the best 'weak' classifier based on current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Predicts the label of a single sample from the alpha-weighted sum of
     * the base classifiers' outputs.
     *
     * @return mixed The label encoded as 1 when the sum is positive,
     *               otherwise the label encoded as -1
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $sum += $this->classifiers[$index]->predict($sample) * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }

    /**
     * Creates a new base classifier and trains it with the consideration of
     * the current sample weights.
     *
     * Classifiers extending WeightedClassifier receive the weights directly;
     * any other classifier is trained on a dataset resampled by weight.
     *
     * @throws \ReflectionException If the base classifier class cannot be reflected
     */
    protected function getBestClassifier(): Classifier
    {
        $ref = new ReflectionClass($this->baseClassifier);
        if ($this->classifierOptions !== []) {
            $classifier = $ref->newInstanceArgs($this->classifierOptions);
        } else {
            $classifier = $ref->newInstance();
        }

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            [$samples, $targets] = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset.
     *
     * Each sample gets (z-score of its weight - lowest z-score + 1) draws and
     * every draw is kept with probability 1/2, so heavier samples tend to be
     * repeated more often.
     *
     * @return array Two-element list: [samples, targets]
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);
        $min = min($weights);

        // When all weights are (numerically) equal - always the case on the
        // first boosting round - the deviation is zero or pure rounding noise.
        // Dividing by it would fail or yield meaningless z-scores, so every
        // sample is given a z-score of 0 instead. Written as !(a > b) so that
        // a NAN deviation is treated the same way.
        $isUniform = !($std > 1e-12 * abs($mean));
        $zScore = static function (float $weight) use ($mean, $std, $isUniform): int {
            return $isUniform ? 0 : (int) round(($weight - $mean) / $std);
        };

        $minZ = $zScore($min);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $draws = $zScore($weight) - $minZ + 1;
            for ($i = 0; $i < $draws; ++$i) {
                if (random_int(0, 1) === 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        // Every draw may be rejected by chance; fall back to the complete
        // dataset rather than handing an empty training set to the classifier.
        if ($samples === []) {
            return [$this->samples, $this->targets];
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the classification error rate,
     * i.e. the weight of the misclassified samples divided by the total weight.
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0;
        foreach ($this->samples as $index => $sample) {
            // Loose comparison is kept on purpose: the type returned by the
            // base classifier's predict() is not known here.
            if ($classifier->predict($sample) != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha of a classifier: 0.5 * ln((1 - error) / error)
     *
     * The error rate is clamped to [1e-10, 1 - 1e-10] so that neither a
     * perfect classifier (error 0, division by zero) nor an always-wrong one
     * (error 1, log(0) = -INF) produces a non-finite alpha.
     */
    protected function calculateAlpha(float $errorRate): float
    {
        $epsilon = 1e-10;
        $errorRate = max($epsilon, min(1 - $epsilon, $errorRate));

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: a weight shrinks when the classifier output
     * agrees with the desired target and grows when it does not.
     *
     * Each new weight is divided by the sum of the weights as it was before
     * this update.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $sumOfWeights = array_sum($this->weights);
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            $weightsT1[] = $weight * exp(-$alpha * $desired * $output) / $sumOfWeights;
        }

        $this->weights = $weightsT1;
    }
}
||
255 |
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using `empty(...)` or `!empty(...)` instead.