Passed
Push — master ( 331d4b...653c7c )
by Arkadiusz
02:19
created

src/Phpml/Classification/Ensemble/AdaBoost.php (2 issues)

Labels
Severity

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Classifier;
8
use Phpml\Classification\Linear\DecisionStump;
9
use Phpml\Classification\WeightedClassifier;
10
use Phpml\Helper\Predictable;
11
use Phpml\Helper\Trainable;
12
use Phpml\Math\Statistic\Mean;
13
use Phpml\Math\Statistic\StandardDeviation;
14
15
/**
 * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to improve the
 * classification performance of 'weak' classifiers such as DecisionStump
 * (the default base classifier of AdaBoost).
 *
 * The ensemble is binary: training maps the two distinct target values to
 * the internal classes 1 and -1, and prediction maps the sign of the weighted
 * vote back to the original label.
 *
 * NOTE(review): $this->samples and $this->targets are not declared in this
 * class; presumably they are provided by the Trainable trait - confirm.
 */
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * Distance kept between an error rate and the bounds 0 and 1 when a
     * classifier weight is computed, so that the logarithm stays finite.
     */
    private const ERROR_RATE_EPSILON = 1e-10;

    /**
     * Actual labels given in the targets array, keyed by the internal
     * class value (1 or -1) they are mapped to
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples accumulated so far
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of features of the first sample passed to train()
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights
     *
     * @var array
     */
    protected $weights = [];

    /**
     * List of selected 'weak' classifiers
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights (same order as $classifiers)
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * Class name of the base ('weak') classifier
     *
     * @var string
     */
    protected $baseClassifier = DecisionStump::class;

    /**
     * Constructor arguments passed to every base classifier instance
     *
     * @var array
     */
    protected $classifierOptions = [];

    /**
     * @param int $maxIterations Number of boosting rounds, i.e. the number of
     *                           base classifiers trained by train()
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Sets the base classifier that will be used for boosting (default = DecisionStump)
     *
     * @param string $baseClassifier    Class name that is instantiated once per iteration
     * @param array  $classifierOptions Constructor arguments for that class
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
    {
        $this->baseClassifier = $baseClassifier;
        $this->classifierOptions = $classifierOptions;
    }

    /**
     * Trains the ensemble: one base classifier is built per iteration and the
     * sample weights are re-computed after each of them.
     *
     * The given samples and targets are appended to the ones already stored.
     *
     * @throws \Exception when $targets does not contain exactly two distinct values
     */
    public function train(array $samples, array $targets): void
    {
        // Validate before touching any state so a failed call leaves the
        // previously stored labels intact.
        $distinctTargets = array_keys(array_count_values($targets));
        if (count($distinctTargets) !== 2) {
            throw new \Exception('AdaBoost is a binary classifier and can classify between two classes only');
        }

        // Set all target values to either -1 or 1
        $this->labels = [1 => $distinctTargets[0], -1 => $distinctTargets[1]];
        foreach ($targets as $target) {
            // Loose comparison on purpose: the labels were read back from
            // array keys, where PHP turns numeric strings such as "1" into
            // integers, so a strict check could miss the original value.
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters: uniform weights, empty ensemble
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
            // Train a 'weak' classifier based on the current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Creates a new instance of the base classifier and trains it with the
     * consideration of the current sample weights.
     *
     * A WeightedClassifier receives the weights directly; any other
     * classifier is trained on a dataset resampled according to the weights.
     *
     * @throws \ReflectionException when the base classifier class cannot be reflected
     */
    protected function getBestClassifier(): Classifier
    {
        $reflection = new \ReflectionClass($this->baseClassifier);
        $classifier = $this->classifierOptions
            ? $reflection->newInstanceArgs($this->classifierOptions)
            : $reflection->newInstance();

        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);

            return $classifier;
        }

        // Short list destructuring (PHP 7.1+) assigns both variables here.
        [$samples, $targets] = $this->resample();
        $classifier->train($samples, $targets);

        return $classifier;
    }

    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset
     *
     * Every sample gets a number of candidate copies derived from the
     * z-score of its weight (at least one); each candidate copy is then kept
     * or dropped at random with rand(0, 1).
     *
     * @return array Two-element list: [samples, targets]
     */
    protected function resample(): array
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);

        // When all weights are equal (e.g. the uniform weights assigned at
        // the start of train()) the standard deviation can be 0. Dividing by
        // it throws DivisionByZeroError on PHP 8 and yields NAN on PHP 7, so
        // every z-score is treated as 0 in that case.
        $zScore = static function ($weight) use ($mean, $std): int {
            return $std > 0 ? (int) round(($weight - $mean) / $std) : 0;
        };

        // Shifting by the smallest z-score guarantees at least one candidate
        // copy per sample.
        $minZ = $zScore(min($weights));

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            $copies = $zScore($weight) - $minZ + 1;
            for ($i = 0; $i < $copies; ++$i) {
                if (rand(0, 1) == 0) {
                    continue;
                }

                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        return [$samples, $targets];
    }

    /**
     * Evaluates the classifier and returns the classification error rate,
     * i.e. the weight of the misclassified samples divided by the total weight
     */
    protected function evaluateClassifier(Classifier $classifier): float
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        foreach ($this->samples as $index => $sample) {
            $predicted = $classifier->predict($sample);
            if ($predicted != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Calculates alpha of a classifier: 0.5 * ln((1 - error) / error)
     *
     * Error rates at or beyond the bounds 0 and 1 are moved just inside
     * them. Without this an error rate of 0 divides by zero, and an error
     * rate of 1 yields log(0) = -INF, which would turn every sample weight
     * into 0 in updateWeights().
     */
    protected function calculateAlpha(float $errorRate): float
    {
        if ($errorRate <= 0.0) {
            $errorRate = self::ERROR_RATE_EPSILON;
        } elseif ($errorRate >= 1.0) {
            $errorRate = 1.0 - self::ERROR_RATE_EPSILON;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights
     *
     * Each weight is multiplied by exp(-alpha * target * prediction) and
     * divided by the sum of the weights as it was before this update.
     */
    protected function updateWeights(Classifier $classifier, float $alpha): void
    {
        $sumOfWeights = array_sum($this->weights);
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            $weightsT1[] = $weight * (exp(-$alpha * $desired * $output) / $sumOfWeights);
        }

        $this->weights = $weightsT1;
    }

    /**
     * Predicts the label of a single sample from the alpha-weighted sum of
     * the base classifiers' predictions.
     *
     * @return mixed The original label mapped to 1 when the sum is positive,
     *               otherwise the label mapped to -1
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $h = $this->classifiers[$index]->predict($sample);
            $sum += $h * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }
}
249