These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more
1 | <?php |
||
2 | |||
3 | declare(strict_types=1); |
||
4 | |||
5 | namespace Phpml\Classification\Ensemble; |
||
6 | |||
7 | use Phpml\Classification\Linear\DecisionStump; |
||
8 | use Phpml\Classification\WeightedClassifier; |
||
9 | use Phpml\Math\Statistic\Mean; |
||
10 | use Phpml\Math\Statistic\StandardDeviation; |
||
11 | use Phpml\Classification\Classifier; |
||
12 | use Phpml\Helper\Predictable; |
||
13 | use Phpml\Helper\Trainable; |
||
14 | |||
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * Actual labels given in the targets array, keyed by the internal sign they
     * are mapped to: [1 => first distinct label seen, -1 => second one]
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of accumulated training samples
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features of the first accumulated sample
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights (renormalized to sum to 1 after every boosting round)
     *
     * @var array
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights, parallel to $classifiers
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * Fully-qualified class name of the base classifier
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations number of boosting rounds run by train()
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    class name instantiated once per boosting round
     * @param array  $classifierOptions constructor arguments for that class
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = [])
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble. Samples and targets accumulate across calls, and the
     * whole boosting procedure is re-run on the accumulated data every time.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when the number of samples and targets differs, or when
     *                    the targets do not describe exactly two classes
     */
    public function train(array $samples, array $targets)
    {
        if (count($samples) !== count($targets)) {
            throw new \Exception('AdaBoost requires the same number of samples and targets');
        }

        // Validate before touching any state so a failed call leaves the model intact.
        $this->labels = $this->resolveLabels($targets);

        // Set all target values to either -1 or 1. Comparing string forms matches
        // how array_count_values() keyed the labels and avoids loose-equality
        // surprises such as '1e1' == '10'.
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $target) {
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        // Re-index so sample indices always line up with targets and weights.
        $this->samples = array_merge($this->samples, array_values($samples));
        $this->sampleCount = count($this->samples);
        $this->featureCount = count($this->samples[0]);

        // Every boosting run starts from a uniform weight distribution.
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Train a 'weak' classifier on the current weights and score it.
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration.
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Determines the sign => label mapping for a training call.
     *
     * On the first call the batch must contain exactly two distinct labels. On
     * later calls the existing mapping is kept (so previously stored ±1 targets
     * stay valid), and every label in the batch must be one of the known two.
     *
     * @param array $targets
     *
     * @return array [1 => label, -1 => label]
     *
     * @throws \Exception
     */
    private function resolveLabels(array $targets): array
    {
        $batchLabels = array_keys(array_count_values($targets));

        if (empty($this->targets) || count($this->labels) !== 2) {
            if (count($batchLabels) !== 2) {
                throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only");
            }

            return [1 => $batchLabels[0], -1 => $batchLabels[1]];
        }

        $known = [(string) $this->labels[1], (string) $this->labels[-1]];
        foreach ($batchLabels as $label) {
            if (!in_array((string) $label, $known, true)) {
                throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only");
            }
        }

        return $this->labels;
    }

    /**
     * Instantiates the base classifier and trains it on the current weighted
     * data set: weights are passed directly to a WeightedClassifier, otherwise
     * the data set is resampled according to the weights.
     *
     * @return Classifier
     */
    protected function getBestClassifier()
    {
        $reflection = new \ReflectionClass($this->baseClassifier);
        // newInstanceArgs([]) behaves like newInstance(), so no emptiness check is needed.
        $classifier = $reflection->newInstanceArgs($this->classifierOptions);

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            list($samples, $targets) = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset
     *
     * Each sample is offered (rounded z-score of its weight - minimum z-score + 1)
     * copies and each copy is kept with probability 1/2, so the result is random.
     *
     * @return array [samples, targets]
     */
    protected function resample()
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);

        // When all weights are equal (e.g. on the first round) the deviation is
        // zero; give every sample the same score instead of dividing by zero.
        $zScores = [];
        foreach ($weights as $index => $weight) {
            $zScores[$index] = $std > 0 ? (int) round(($weight - $mean) / $std) : 0;
        }
        $minZ = min($zScores);

        $samples = [];
        $targets = [];
        foreach ($zScores as $index => $z) {
            $copies = $z - $minZ + 1;
            for ($i = 0; $i < $copies; ++$i) {
                if (rand(0, 1) === 0) {
                    continue;
                }
                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        // Every copy may be dropped by chance (likely for tiny data sets); never
        // hand the base classifier an empty training set.
        if (empty($samples)) {
            return [$this->samples, $this->targets];
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the weighted classification error rate
     *
     * @param Classifier $classifier
     *
     * @return float share of the total weight carried by misclassified samples
     */
    protected function evaluateClassifier(Classifier $classifier)
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0;
        foreach ($this->samples as $index => $sample) {
            if ($classifier->predict($sample) != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha of a classifier: 0.5 * ln((1 - error) / error)
     *
     * @param float $errorRate
     *
     * @return float
     */
    protected function calculateAlpha(float $errorRate)
    {
        // Keep the error rate strictly inside (0, 1): 0 would divide by zero and
        // 1 would take log(0); both give infinite alphas that turn the sample
        // weights into INF/NAN.
        if ($errorRate <= 0.0) {
            $errorRate = 1e-10;
        } elseif ($errorRate >= 1.0) {
            $errorRate = 1.0 - 1e-10;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: w_i *= exp(-alpha * y_i * h(x_i)), then
     * renormalizes them so they sum to 1.
     *
     * @param Classifier $classifier
     * @param float      $alpha
     */
    protected function updateWeights(Classifier $classifier, float $alpha)
    {
        $updated = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);
            // Agreeing prediction and target shrink the weight; disagreement grows it.
            $updated[] = $weight * exp(-$alpha * $desired * $output);
        }

        $sum = array_sum($updated);
        if ($sum > 0) {
            $updated = array_map(function ($weight) use ($sum) {
                return $weight / $sum;
            }, $updated);
        }

        $this->weights = $updated;
    }

    /**
     * Predicts by the sign of the alpha-weighted vote of all base classifiers.
     *
     * @param array $sample
     *
     * @return mixed one of the two original labels
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $sum += $alpha * $this->classifiers[$index]->predict($sample);
        }

        // Non-positive scores (including ties) map to the second label.
        return $this->labels[$sum > 0 ? 1 : -1];
    }
}
||
270 |
This check marks implicit conversions of arrays to boolean values in a comparison. While an empty array is considered equal (but not identical) to false in PHP, this is not always apparent. Consider making the comparison explicit by using empty(...) or !empty(...) instead.