AdaBoost   A
last analyzed

Complexity

Total Complexity 24

Size/Duplication

Total Lines 234
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 24
eloc 84
dl 0
loc 234
rs 10
c 0
b 0
f 0

9 Methods

Rating   Name   Duplication   Size   Complexity  
A getBestClassifier() 0 15 3
A train() 0 39 5
A __construct() 0 3 1
A evaluateClassifier() 0 12 3
A updateWeights() 0 14 2
A setBaseClassifier() 0 4 1
A calculateAlpha() 0 7 2
A resample() 0 23 4
A predictSample() 0 9 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Classifier;
8
use Phpml\Classification\Linear\DecisionStump;
9
use Phpml\Classification\WeightedClassifier;
10
use Phpml\Exception\InvalidArgumentException;
11
use Phpml\Helper\Predictable;
12
use Phpml\Helper\Trainable;
13
use Phpml\Math\Statistic\Mean;
14
use Phpml\Math\Statistic\StandardDeviation;
15
use ReflectionClass;
16
17
class AdaBoost implements Classifier
{
    use Predictable;
    use Trainable;

    /**
     * Error rates are clamped into [EPSILON, 1 - EPSILON] before computing alpha so that
     * a perfect (or perfectly wrong) weak classifier gets a large but finite vote weight.
     */
    private const ERROR_RATE_EPSILON = 1e-10;

    /**
     * Relative standard deviation (std / mean) below which the sample weights are treated
     * as uniform during resampling; below it the z-scores are undefined or rounding noise.
     */
    private const UNIFORM_WEIGHT_TOLERANCE = 1e-9;

    /**
     * Original labels keyed by their internal encoding: [1 => firstLabel, -1 => secondLabel]
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples accumulated over all train() calls
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features of the first sample of the latest training batch
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights (normalised to sum to 1 after every boosting round)
     *
     * @var array
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights (one alpha per entry of $classifiers)
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations number of boosting rounds (weak classifiers) to train
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    fully-qualified name of a class implementing Classifier
     * @param array  $classifierOptions constructor arguments for each base classifier instance
     *
     * @throws InvalidArgumentException when $baseClassifier does not implement Classifier
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        // Fail fast here instead of with a ReflectionException / fatal error deep inside train()
        if (!is_a($baseClassifier, Classifier::class, true)) {
            throw new InvalidArgumentException(sprintf('Base classifier "%s" must implement %s', $baseClassifier, Classifier::class));
        }

        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble on a binary-labelled dataset.
     *
     * The first distinct label seen in $targets is encoded internally as +1 and the
     * second as -1; predictSample() maps the encoding back to the original labels.
     * Samples and encoded targets are appended to those of previous train() calls, and
     * boosting is re-run from uniform weights over the whole accumulated dataset.
     *
     * @throws InvalidArgumentException when $targets does not contain exactly two distinct
     *                                  labels, or when $samples and $targets differ in size
     */
    public function train(array $samples, array $targets): void
    {
        // Work on a local copy so a rejected call leaves the model state untouched
        $labels = array_keys(array_count_values($targets));
        if (count($labels) !== 2) {
            throw new InvalidArgumentException('AdaBoost is a binary classifier and can classify between two classes only');
        }

        if (count($samples) !== count($targets)) {
            throw new InvalidArgumentException('Number of samples and targets must be equal');
        }

        // Set all target values to either -1 or 1. The loose comparison is intentional:
        // array_count_values() turns numeric-string labels into int keys, so '1' must match 1.
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];
        foreach ($targets as $target) {
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        // reset() also works when the caller's $samples array is not 0-indexed
        $this->featureCount = count(reset($samples));
        $this->sampleCount = count($this->samples);

        // Start from a uniform distribution over all accumulated samples
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Determine the best 'weak' classifier based on current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Compute its vote weight, then shift sample weight towards its mistakes
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Predicts the label of a single sample as the sign of the alpha-weighted vote of all
     * weak classifiers; a vote sum of exactly 0 resolves to the label encoded as -1.
     *
     * @return mixed one of the two original labels seen during training
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }

    /**
     * Instantiates the base classifier and trains it on the current weight distribution:
     * directly via sample weights when it is a WeightedClassifier, otherwise on a
     * weight-driven resample of the data.
     */
    protected function getBestClassifier(): Classifier
    {
        $ref = new ReflectionClass($this->baseClassifier);
        /** @var Classifier $classifier */
        $classifier = count($this->classifierOptions) === 0 ? $ref->newInstance() : $ref->newInstanceArgs($this->classifierOptions);

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            [$samples, $targets] = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and returns the new dataset.
     *
     * Each sample gets a number of draws that grows with the z-score of its weight (the
     * lowest-weighted sample gets one draw); every draw keeps the sample with probability 1/2.
     *
     * @return array [samples, targets]
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $mean = Mean::arithmetic($weights);
        $std = StandardDeviation::population($weights);

        // With (near-)uniform weights the z-scores are undefined (std == 0 would throw
        // DivisionByZeroError) or pure rounding noise: give every sample one draw.
        $uniform = $std <= self::UNIFORM_WEIGHT_TOLERANCE * abs($mean);
        $minZ = $uniform ? 0 : (int) round((min($weights) - $mean) / $std);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $draws = $uniform ? 1 : (int) round(($weight - $mean) / $std) - $minZ + 1;
            for ($i = 0; $i < $draws; ++$i) {
                if (random_int(0, 1) === 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the weighted classification error rate
     * (sum of weights of misclassified samples divided by the total weight).
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0;
        foreach ($this->samples as $index => $sample) {
            $predicted = $classifier->predict($sample);
            // Loose comparison: base classifiers may return the +1/-1 label as int, float or string
            if ($predicted != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha (the vote weight) of a classifier: 0.5 * ln((1 - e) / e).
     */
    protected function calculateAlpha(float $errorRate): float
    {
        // Clamp into (0, 1): e = 0 would divide by zero, e = 1 would give log(0) = -INF
        // and turn every sample weight into INF/NaN in updateWeights().
        $errorRate = min(max($errorRate, self::ERROR_RATE_EPSILON), 1 - self::ERROR_RATE_EPSILON);

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: w_i <- w_i * exp(-alpha * y_i * h(x_i)) / Z, where Z is
     * the sum of the updated weights, so the result is again a distribution.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $updated = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            // Correct predictions (desired * output = +1) shrink, mistakes grow
            $updated[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalise by the sum of the *updated* weights; dividing by the previous sum lets
        // the weights drift towards underflow/overflow over many boosting rounds.
        $normalizer = array_sum($updated);
        if ($normalizer > 0) {
            foreach ($updated as $index => $weight) {
                $updated[$index] = $weight / $normalizer;
            }
        }

        $this->weights = $updated;
    }
}
253