1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification\Ensemble; |
6
|
|
|
|
7
|
|
|
use Phpml\Classification\Linear\DecisionStump; |
8
|
|
|
use Phpml\Classification\Classifier; |
9
|
|
|
use Phpml\Helper\Predictable; |
10
|
|
|
use Phpml\Helper\Trainable; |
11
|
|
|
|
12
|
|
|
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * Actual labels given in the targets array, keyed by their internal
     * encoding: 1 => first distinct label seen, -1 => second distinct label seen
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples accumulated in $this->samples
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features, taken from the first sample of the latest train() call
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights; normalized to sum to 1 after every boosting round
     *
     * @var array
     */
    protected $weights = [];

    /**
     * Base classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations Upper bound on the number of boosting rounds
     */
    public function __construct(int $maxIterations = 30)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Builds the boosted ensemble of decision stumps.
     *
     * NOTE(review): samples/targets accumulate across calls (array_merge / append),
     * but the label encoding is recomputed from the latest $targets only; previously
     * encoded targets are not re-encoded. Confirm this is the intended contract.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception                if $targets does not hold exactly two distinct labels
     * @throws \InvalidArgumentException if $samples and $targets differ in size
     */
    public function train(array $samples, array $targets)
    {
        // Initialize usual variables
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) !== 2) {
            throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
        }

        // Each sample needs exactly one target; otherwise the weight/target
        // lookups below go out of range.
        if (count($samples) !== count($targets)) {
            throw new \InvalidArgumentException('The number of samples must match the number of targets');
        }

        // Set all target values to either -1 or 1
        $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $target) {
            // array_count_values() turns integer-like string keys into ints, so compare
            // string forms: loose == would treat e.g. '1' and '01' as the same label.
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Determine the best 'weak' classifier based on current weights
            // and update alpha & weight values at each iteration
            list($classifier, $errorRate) = $this->getBestClassifier();
            if ($classifier === null) {
                // No stump could be built (e.g. samples without features)
                break;
            }

            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;

            if ($errorRate == 0) {
                // A perfect stump scales every weight by the same factor, so the
                // normalized distribution is unchanged and later rounds would
                // presumably just re-select an equivalent stump.
                break;
            }
        }
    }

    /**
     * Trains one DecisionStump per feature on the current sample weights and
     * returns the one with the lowest training error rate.
     *
     * @return array [DecisionStump|null $classifier, float $errorRate]; the classifier
     *               is null only when there are no features to split on
     */
    protected function getBestClassifier()
    {
        // This method works only for "DecisionStump" classifier, for now.
        // As a future task, it will be generalized enough to work with other
        // classifiers as well
        $minErrorRate = 1.0;
        $bestClassifier = null;
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $stump = new DecisionStump($i);
            $stump->setSampleWeights($this->weights);
            $stump->train($this->samples, $this->targets);

            $errorRate = $stump->getTrainingErrorRate();
            // Always accept the first stump so a classifier is returned whenever
            // at least one feature exists.
            if ($bestClassifier === null || $errorRate < $minErrorRate) {
                $bestClassifier = $stump;
                $minErrorRate = $errorRate;
            }
        }

        return [$bestClassifier, $minErrorRate];
    }

    /**
     * Calculates alpha of a classifier: 0.5 * ln((1 - e) / e)
     *
     * @param float $errorRate
     *
     * @return float
     */
    protected function calculateAlpha(float $errorRate)
    {
        // Keep the error strictly inside (0, 1): 0 would divide by zero and 1 would
        // give log(0) = -INF, turning every subsequent weight into INF/NAN.
        $epsilon = 1e-10;
        $errorRate = max($epsilon, min(1 - $epsilon, $errorRate));

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: w_i <- w_i * exp(-alpha * y_i * h(x_i)) / Z,
     * where Z is the sum of the updated weights.
     *
     * @param DecisionStump $classifier
     * @param float         $alpha
     */
    protected function updateWeights(DecisionStump $classifier, float $alpha)
    {
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            // desired * output is -1 for a misclassified sample, raising its weight
            $weightsT1[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalize by the sum of the *updated* weights so they form a distribution
        // again; dividing by the previous sum (as before) does not achieve that.
        $sumOfWeights = array_sum($weightsT1);
        if ($sumOfWeights > 0) {
            $weightsT1 = array_map(function ($weight) use ($sumOfWeights) {
                return $weight / $sumOfWeights;
            }, $weightsT1);
        }

        $this->weights = $weightsT1;
    }

    /**
     * Predicts a label by an alpha-weighted vote of the trained stumps.
     * A vote sum of exactly 0 (including when no stump was trained) yields the
     * label encoded as -1.
     *
     * @param array $sample
     *
     * @return mixed
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }
}
191
|
|
|
|
Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.
You can also find more detailed suggestions in the “Code” section of your repository.