Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Classification/Ensemble/AdaBoost.php (2 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Exception;
8
use Phpml\Classification\Classifier;
9
use Phpml\Classification\Linear\DecisionStump;
10
use Phpml\Classification\WeightedClassifier;
11
use Phpml\Helper\Predictable;
12
use Phpml\Helper\Trainable;
13
use Phpml\Math\Statistic\Mean;
14
use Phpml\Math\Statistic\StandardDeviation;
15
use ReflectionClass;
16
17
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * Actual labels given in the targets array. After training this is
     * keyed by the internal encoding: [1 => first label seen, -1 => second label seen]
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples accumulated over all train() calls
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features, taken from the first sample of the latest train() call
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights, one per training sample (normalised to sum to 1)
     *
     * @var array
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights (one alpha per entry in $classifiers)
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * Class name of the weak learner instantiated at every iteration
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every new base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    class name, instantiated via reflection for every iteration
     * @param array  $classifierOptions constructor arguments for that class (empty = no arguments)
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble on the given samples (appended to any previously
     * trained samples) for at most $maxIterations boosting rounds.
     *
     * @throws \Exception when the targets do not contain exactly two distinct labels
     */
    public function train(array $samples, array $targets): void
    {
        // Validate before touching any state so a rejected call leaves the model intact
        $labels = array_keys(array_count_values($targets));
        if (count($labels) != 2) {
            throw new Exception('AdaBoost is a binary classifier and can classify between two classes only');
        }

        // Encode the first seen label as +1 and the second as -1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];
        foreach ($targets as $target) {
            // Loose comparison on purpose: array_count_values() turns numeric
            // string keys into ints, so a strict check could miss '1' vs 1.
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Start from a uniform distribution over all accumulated samples
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Train a fresh 'weak' classifier against the current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Predicts by the sign of the alpha-weighted vote of all weak classifiers.
     * A sum of exactly 0 resolves to the label encoded as -1.
     *
     * @return mixed
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }

    /**
     * Instantiates a new base classifier and trains it with the current
     * sample weights. Weight-aware classifiers receive the weights directly;
     * any other classifier is trained on a weight-driven resample.
     */
    protected function getBestClassifier(): Classifier
    {
        $ref = new ReflectionClass($this->baseClassifier);
        if ($this->classifierOptions !== []) {
            $classifier = $ref->newInstanceArgs($this->classifierOptions);
        } else {
            $classifier = $ref->newInstance();
        }

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            // Array destructuring declares and assigns both $samples and $targets
            [$samples, $targets] = $this->resample();
            $classifier->train($samples, $targets);
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset.
     *
     * Each sample gets a number of coin-flip draws that grows with the z-score
     * of its weight. The lowest-weighted sample gets one draw.
     *
     * @return array{0: array, 1: array} [samples, targets]
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);
        $min = min($weights);

        // With (numerically) identical weights the standard deviation is 0, e.g.
        // on the first iteration. Z-scores are then undefined (division by zero
        // throws in PHP 8), so give every sample a single draw. This is what
        // the formula below yields when all differences are zero.
        $uniform = $std <= 1e-12 * $mean;
        $minZ = $uniform ? 0 : (int) round(($min - $mean) / $std);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $z = $uniform ? 1 : (int) round(($weight - $mean) / $std) - $minZ + 1;
            for ($i = 0; $i < $z; ++$i) {
                if (random_int(0, 1) === 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        // Every coin flip can fail (likely for tiny datasets). Fall back to the
        // full dataset rather than training the base classifier on nothing.
        if ($samples === []) {
            return [$this->samples, $this->targets];
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the classification error rate,
     * i.e. the share of total sample weight that it misclassifies.
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        foreach ($this->samples as $index => $sample) {
            $predicted = $classifier->predict($sample);
            if ($predicted != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha of a classifier: 0.5 * ln((1 - e) / e).
     */
    protected function calculateAlpha(float $errorRate): float
    {
        // Keep the error rate strictly inside (0, 1). At 0 or 1, log() yields
        // +/-INF, and an infinite alpha would drive every sample weight to
        // 0 or INF in updateWeights().
        if ($errorRate <= 0.0) {
            $errorRate = 1e-10;
        } elseif ($errorRate >= 1.0) {
            $errorRate = 1.0 - 1e-10;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights. Weights of correctly classified samples
     * shrink and weights of misclassified samples grow. The result is
     * normalised to sum to 1.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            // Targets are encoded as +1/-1 in train(), so desired * output is
            // positive for a correct prediction and negative for a wrong one.
            $weightsT1[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalise by the sum of the *updated* weights (standard AdaBoost)
        $sumOfWeights = array_sum($weightsT1);
        if ($sumOfWeights > 0) {
            foreach ($weightsT1 as $index => $weight) {
                $weightsT1[$index] = $weight / $sumOfWeights;
            }
        }

        $this->weights = $weightsT1;
    }
}
255