Passed
Push — master (e83f7b...d953ef)
by Arkadiusz
03:28
created

src/Phpml/Classification/Ensemble/AdaBoost.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Exception;
8
use Phpml\Classification\Classifier;
9
use Phpml\Classification\Linear\DecisionStump;
10
use Phpml\Classification\WeightedClassifier;
11
use Phpml\Helper\Predictable;
12
use Phpml\Helper\Trainable;
13
use Phpml\Math\Statistic\Mean;
14
use Phpml\Math\Statistic\StandardDeviation;
15
use ReflectionClass;
16
17
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * The error rate is clamped to [ERROR_RATE_EPSILON, 1 - ERROR_RATE_EPSILON]
     * before computing alpha so that log() always receives a finite, positive ratio.
     */
    private const ERROR_RATE_EPSILON = 1e-10;

    /**
     * Actual labels given in the targets array, keyed by their internal
     * encoding: [1 => first label, -1 => second label]
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples accumulated over all train() calls
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features, taken from the first sample of the last train() call
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of boosting rounds, i.e. number of weak classifiers trained
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights; re-normalized to sum to 1 after every boosting round
     *
     * @var float[]
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers, one per boosting round
     *
     * @var Classifier[]
     */
    protected $classifiers = [];

    /**
     * Base classifier weights (alpha), aligned by index with $classifiers
     *
     * @var float[]
     */
    protected $alpha = [];

    /**
     * Class name of the weak learner that is boosted
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every new base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations number of boosting rounds (weak classifiers) to train
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    fully-qualified class name of the weak learner
     * @param array  $classifierOptions arguments forwarded to its constructor
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the boosted ensemble.
     *
     * Samples and targets are appended to those of previous train() calls and
     * the whole ensemble is re-fitted on the accumulated data. Together with
     * labels seen in earlier calls, exactly two distinct labels are required.
     *
     * @throws Exception when the targets do not yield exactly two distinct labels
     */
    public function train(array $samples, array $targets): void
    {
        // Labels from previous train() calls come first so that the +1/-1
        // encoding of already-stored targets stays valid.
        $labels = array_values(array_unique(array_merge(
            array_values($this->labels),
            array_keys(array_count_values($targets))
        )));
        if (count($labels) !== 2) {
            throw new Exception('AdaBoost is a binary classifier and can classify between two classes only');
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];
        foreach ($targets as $target) {
            // array_count_values() turns integer-like string keys into ints, so
            // compare string representations instead of relying on loose `==`.
            $this->targets[] = (string) $target === (string) $this->labels[1] ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters: uniform weights over all samples
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Determine the best 'weak' classifier based on current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Predicts the label of one sample from the sign of the alpha-weighted vote
     * of all weak classifiers; a zero sum (including an empty ensemble) maps to
     * the label encoded as -1.
     *
     * @return mixed
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $sum += $alpha * $this->classifiers[$index]->predict($sample);
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }

    /**
     * Returns the classifier with the lowest error rate with the
     * consideration of current sample weights.
     *
     * Weighted base classifiers are trained directly with the sample weights;
     * any other classifier is trained on a weight-driven resample of the data.
     */
    protected function getBestClassifier(): Classifier
    {
        // newInstanceArgs([]) is equivalent to newInstance(), so no emptiness check is needed.
        $classifier = (new ReflectionClass($this->baseClassifier))->newInstanceArgs($this->classifierOptions);

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            [$samples, $targets] = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset as [samples, targets].
     *
     * Each sample gets a number of candidate copies derived from the z-score of
     * its weight; every candidate copy is kept with probability 1/2.
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);

        // With identical weights the standard deviation is 0: every z-score is
        // equal, so each sample simply gets one candidate copy (avoids 0/0).
        $uniform = $std <= 0.0;
        $minZ = $uniform ? 0 : (int) round((min($weights) - $mean) / $std);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $copies = $uniform ? 1 : (int) round(($weight - $mean) / $std) - $minZ + 1;
            for ($i = 0; $i < $copies; ++$i) {
                if (random_int(0, 1) === 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        // Every copy may be dropped by chance; never hand an empty dataset to
        // the base classifier — fall back to the full (unweighted) dataset.
        if ($samples === []) {
            return [$this->samples, $this->targets];
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the weighted classification error
     * rate: the weight of misclassified samples divided by the total weight.
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        foreach ($this->samples as $index => $sample) {
            // Loose comparison on purpose: the base classifier may return the
            // -1/1 target in a different scalar type.
            if ($classifier->predict($sample) != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha of a classifier: 0.5 * ln((1 - e) / e).
     *
     * The error rate is clamped away from 0 and 1 so that a perfect (or
     * perfectly wrong) classifier yields a large but finite alpha instead of
     * +/-INF, which would otherwise turn all subsequent weights into INF/NAN.
     */
    protected function calculateAlpha(float $errorRate): float
    {
        $errorRate = max(self::ERROR_RATE_EPSILON, min(1 - self::ERROR_RATE_EPSILON, $errorRate));

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: w_i <- w_i * exp(-alpha * y_i * h(x_i)) / Z,
     * where Z is the sum of the updated weights, so the result sums to 1.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $updated = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            $updated[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalize by the sum of the *new* weights (Z_t)
        $normalizer = array_sum($updated);
        foreach ($updated as $index => $weight) {
            $updated[$index] = $weight / $normalizer;
        }

        $this->weights = $updated;
    }
}
255