Test Setup Failed
Push — master ( 3baf15...4590d5 )
by Arkadiusz
02:24
created

src/Classification/Ensemble/AdaBoost.php (2 issues)

Labels
Severity

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Classifier;
8
use Phpml\Classification\Linear\DecisionStump;
9
use Phpml\Classification\WeightedClassifier;
10
use Phpml\Exception\InvalidArgumentException;
11
use Phpml\Helper\Predictable;
12
use Phpml\Helper\Trainable;
13
use Phpml\Math\Statistic\Mean;
14
use Phpml\Math\Statistic\StandardDeviation;
15
use ReflectionClass;
16
17
/**
 * ADAptive BOOSTing (AdaBoost): a binary ensemble classifier that improves
 * the performance of 'weak' base classifiers (DecisionStump by default).
 *
 * Training runs a fixed number of iterations. Each iteration trains a fresh
 * base classifier against the current sample weights and measures its
 * weighted error rate. That error rate becomes the classifier's vote weight
 * (alpha), and the samples are then re-weighted so that misclassified ones
 * count more in the next iteration. A prediction is the sign of the
 * alpha-weighted sum of the base classifiers' +1 / -1 outputs.
 */
class AdaBoost implements Classifier
{
    use Predictable;
    use Trainable;

    /**
     * Original target labels, keyed by their internal encoding:
     * 1 => first distinct label in the targets, -1 => second distinct label.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * @var int
     */
    protected $sampleCount;

    /**
     * @var int
     */
    protected $featureCount;

    /**
     * Number of boosting iterations (base classifiers) built by train().
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Current sample weights, one per training sample.
     *
     * @var array
     */
    protected $weights = [];

    /**
     * Trained 'weak' classifiers, one per boosting iteration.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Vote weight (alpha) of each entry in $classifiers, in the same order.
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * Class name of the base classifier that is instantiated each iteration.
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to the base classifier (empty = no args).
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
     * improve classification performance of 'weak' classifiers such as
     * DecisionStump (default base classifier of AdaBoost).
     *
     * @param int $maxIterations number of boosting iterations performed by train()
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    class name instantiated via reflection each iteration
     * @param array  $classifierOptions constructor arguments for that class
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble on the given samples and targets.
     *
     * @throws InvalidArgumentException when the targets do not contain exactly two distinct labels
     */
    public function train(array $samples, array $targets): void
    {
        // Initialize usual variables
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) !== 2) {
            throw new InvalidArgumentException('AdaBoost is a binary classifier and can classify between two classes only');
        }

        // Encode the two labels as +1 / -1 so that the weight update and the
        // final vote can work with the product of target and prediction.
        $this->labels = [
            1 => $this->labels[0],
            -1 => $this->labels[1],
        ];
        foreach ($targets as $target) {
            // Loose comparison on purpose: array_count_values() turns numeric
            // string labels into int keys, so '1' must still match 1.
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Start from a uniform weight distribution
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Determine the best 'weak' classifier based on current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Predicts a single sample as the sign of the alpha-weighted vote of all
     * trained base classifiers; a sum of exactly 0 maps to the -1 label.
     *
     * @return mixed one of the two original target labels
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }

    /**
     * Instantiates a fresh base classifier and trains it with the current
     * sample weights.
     *
     * A WeightedClassifier receives the weights directly. Any other
     * classifier is trained on a dataset resampled according to the weights.
     */
    protected function getBestClassifier(): Classifier
    {
        $ref = new ReflectionClass($this->baseClassifier);
        /** @var Classifier $classifier */
        $classifier = count($this->classifierOptions) === 0 ? $ref->newInstance() : $ref->newInstanceArgs($this->classifierOptions);

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);
        } else {
            // resample() returns the pair [samples, targets]; unpack it
            // straight into train()'s two parameters.
            $classifier->train(...$this->resample());
        }

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and returns the
     * new dataset as [samples, targets].
     *
     * Each sample gets (its rounded weight z-score minus the minimum rounded
     * z-score, plus 1) draws. Each draw keeps the sample with probability 1/2,
     * so heavier samples tend to appear more often. The result may contain
     * fewer samples than the original dataset.
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);

        // All weights equal (e.g. the uniform start) gives a standard
        // deviation of exactly 0. Every deviation from the mean is then 0 as
        // well, so any non-zero divisor yields the intended z-score of 0
        // instead of a division by zero.
        $divisor = $std > 0.0 ? $std : 1.0;

        $minZ = (int) round((min($weights) - $mean) / $divisor);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $draws = (int) round(($weight - $mean) / $divisor) - $minZ + 1;
            for ($i = 0; $i < $draws; ++$i) {
                if (random_int(0, 1) === 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns its weighted classification error
     * rate: the share of total sample weight that it misclassifies.
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        foreach ($this->samples as $index => $sample) {
            $predicted = $classifier->predict($sample);
            if ($predicted != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates the vote weight (alpha) of a classifier from its error rate:
     * alpha = 0.5 * ln((1 - e) / e).
     */
    protected function calculateAlpha(float $errorRate): float
    {
        // Keep the error rate strictly inside (0, 1). A perfect classifier
        // (0) would divide by zero. A classifier that is always wrong (1)
        // would take log(0), giving an infinite alpha that turns every
        // weight into INF/NaN in updateWeights().
        $epsilon = 1e-10;
        if ($errorRate <= 0.0) {
            $errorRate = $epsilon;
        } elseif ($errorRate >= 1.0) {
            $errorRate = 1.0 - $epsilon;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: w_i <- w_i * exp(-alpha * y_i * h(x_i)),
     * then renormalises so that the new weights sum to 1.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $updated = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            // desired * output is +1 for a correct prediction (weight shrinks
            // when alpha > 0) and -1 for a mistake (weight grows).
            $updated[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalise by the sum of the *updated* weights (Z_t), not the old
        // ones, so the result is a proper distribution again.
        $total = array_sum($updated);
        if ($total > 0) {
            foreach ($updated as $index => $weight) {
                $updated[$index] = $weight / $total;
            }
        }

        $this->weights = $updated;
    }
}
253