Passed
Push — master ( 47cdff...ed5fc8 )
by Arkadiusz
03:38
created

src/Phpml/Classification/Ensemble/AdaBoost.php (3 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis, consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Linear\DecisionStump;
8
use Phpml\Classification\WeightedClassifier;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
use Phpml\Classification\Classifier;
12
use Phpml\Helper\Predictable;
13
use Phpml\Helper\Trainable;
14
15
class AdaBoost implements Classifier
16
{
17
    use Predictable, Trainable;
18
19
    /**
20
     * The two distinct labels found in the targets array given to train().
     * Once train() has run they are keyed by their internal encoding:
     * index 1 holds the first label and index -1 holds the second one.
21
     * @var array
22
     */
23
    protected $labels = [];
24
25
    /**
     * Number of samples held by the model; set in train() after the new
     * samples have been merged with the ones already stored.
26
     * @var int
27
     */
28
    protected $sampleCount;
29
30
    /**
     * Number of features, taken in train() from the first sample passed in
31
     * @var int
32
     */
33
    protected $featureCount;
34
35
    /**
36
     * Number of maximum iterations to be done, i.e. the number of boosting
     * rounds executed by the loop in train()
37
     *
38
     * @var int
39
     */
40
    protected $maxIterations;
41
42
    /**
43
     * Sample weights: one entry per stored sample, initialised in train() to
     * 1 / sampleCount and replaced after every round by updateWeights()
44
     *
45
     * @var array
46
     */
47
    protected $weights = [];
48
49
    /**
50
     * List of selected 'weak' classifiers; train() appends one per round
51
     *
52
     * @var array
53
     */
54
    protected $classifiers = [];
55
56
    /**
57
     * Base classifier weights; entry N belongs to the classifier stored
     * at the same index N in $classifiers
58
     *
59
     * @var array
60
     */
61
    protected $alpha = [];
62
63
    /**
     * Class name of the 'weak' classifier that getBestClassifier()
     * instantiates through reflection
64
     * @var string
65
     */
66
    protected $baseClassifier = DecisionStump::class;
67
68
    /**
     * Constructor arguments handed to the base classifier when it is
     * instantiated; an empty array means the constructor gets no arguments
69
     * @var array
70
     */
71
    protected $classifierOptions = [];
72
73
    /**
     * Creates an AdaBoost (ADAptive BOOSTing) ensemble.
     *
     * AdaBoost is an ensemble algorithm that improves the classification
     * performance of 'weak' classifiers such as DecisionStump, which is the
     * default base classifier.
     *
     * @param int $maxIterations Number of boosting rounds train() will run
     */
    public function __construct(int $maxIterations = 50)
    {
        $this->maxIterations = $maxIterations;
    }
84
85
    /**
     * Chooses the 'weak' classifier used for boosting.
     *
     * Calling this method without arguments restores the defaults:
     * DecisionStump with no constructor arguments.
     *
     * @param string $baseClassifier    Class name of the classifier to instantiate in every round
     * @param array  $classifierOptions Arguments passed to that classifier's constructor
     */
    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = [])
    {
        $this->classifierOptions = $classifierOptions;
        $this->baseClassifier = $baseClassifier;
    }
96
97
    /**
     * Trains the ensemble on a dataset with exactly two distinct labels.
     *
     * The labels are encoded internally as 1 and -1, the samples are appended
     * to the ones stored by earlier calls, and the boosting loop then runs for
     * $maxIterations rounds, keeping one trained classifier and its alpha per
     * round.
     *
     * Note: the label encoding is derived from the targets of the current
     * call only.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception If the number of samples differs from the number of
     *                    targets, or the targets do not hold exactly two
     *                    distinct labels
     */
    public function train(array $samples, array $targets)
    {
        // Samples and targets are paired by position, so their sizes must agree
        if (count($samples) !== count($targets)) {
            throw new \Exception('The number of samples must match the number of targets');
        }

        // Initialize usual variables
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) !== 2) {
            throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only");
        }

        // Set all target values to either -1 or 1. Labels are compared as strings:
        // array_count_values() turns numeric string labels into integer keys, and a
        // loose comparison would also match different labels (0 == 'a' before PHP 8,
        // '1e1' == '10' in every version).
        $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $target) {
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        // reset() yields the first sample whatever its key is; $samples cannot be
        // empty here because two distinct targets were found above
        $this->featureCount = count(reset($samples));
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters: every sample starts with the same weight
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        $currIter = 0;
        while ($this->maxIterations > $currIter++) {
            // Train a 'weak' classifier with regard to the current weights
            $classifier = $this->getBestClassifier();
            $errorRate = $this->evaluateClassifier($classifier);

            // Update alpha & weight values at each iteration
            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }
141
142
    /**
     * Builds a new instance of the base classifier and trains it with regard
     * to the current sample weights.
     *
     * A WeightedClassifier receives the weights directly and is trained on the
     * stored dataset. Any other classifier is trained on the dataset returned
     * by resample() instead.
     *
     * @return Classifier
     */
    protected function getBestClassifier()
    {
        $ref = new \ReflectionClass($this->baseClassifier);

        // Explicit emptiness check instead of an implicit array-to-bool conversion
        $classifier = empty($this->classifierOptions)
            ? $ref->newInstance()
            : $ref->newInstanceArgs($this->classifierOptions);

        // instanceof behaves the same for classes and interfaces, unlike
        // is_subclass_of() on PHP versions affected by PHP bug #53727
        if ($classifier instanceof WeightedClassifier) {
            $classifier->setSampleWeights($this->weights);
            $classifier->train($this->samples, $this->targets);

            return $classifier;
        }

        list($samples, $targets) = $this->resample();
        $classifier->train($samples, $targets);

        return $classifier;
    }
167
168
    /**
     * Resamples the dataset in accordance with the weights and
     * returns the new dataset.
     *
     * Each sample gets a number of draws derived from the rounded z-score of
     * its weight (the lowest-weighted sample gets one draw), and every draw is
     * kept with a probability of one half. The result can therefore be smaller
     * than the stored dataset.
     *
     * @return array Two-element list: the resampled samples and their targets
     */
    protected function resample()
    {
        $weights = $this->weights;
        $std = StandardDeviation::population($weights);
        $mean = Mean::arithmetic($weights);
        $min = min($weights);

        // Rounded z-score of a weight. When all weights are equal (always true in
        // the first boosting round) the deviation is zero; the z-score is then
        // defined as 0 instead of dividing by zero, which raises a warning on
        // PHP 7 and a DivisionByZeroError on PHP 8.
        $zScore = function ($weight) use ($mean, $std) {
            return $std > 0 ? (int) round(($weight - $mean) / $std) : 0;
        };
        $minZ = $zScore($min);

        $samples = [];
        $targets = [];
        foreach ($weights as $index => $weight) {
            // Shift the scores so that the smallest weight maps to exactly one draw
            $z = $zScore($weight) - $minZ + 1;
            for ($i = 0; $i < $z; ++$i) {
                if (rand(0, 1) == 0) {
                    continue;
                }
                $samples[] = $this->samples[$index];
                $targets[] = $this->targets[$index];
            }
        }

        return [$samples, $targets];
    }
197
198
    /**
     * Measures how badly a classifier performs on the stored dataset.
     *
     * The result is the summed weight of the samples the classifier gets
     * wrong, divided by the summed weight of all samples.
     *
     * @param Classifier $classifier
     *
     * @return float
     */
    protected function evaluateClassifier(Classifier $classifier)
    {
        $weightSum = (float) array_sum($this->weights);

        $missedWeight = 0;
        foreach ($this->samples as $i => $row) {
            if ($classifier->predict($row) == $this->targets[$i]) {
                continue;
            }

            $missedWeight += $this->weights[$i];
        }

        return $missedWeight / $weightSum;
    }
218
219
    /**
     * Calculates alpha (the voting weight) of a classifier from its error rate.
     *
     * Alpha is 0.5 * ln((1 - e) / e): positive for error rates below one half,
     * zero at one half and negative above it.
     *
     * @param float $errorRate Weighted error rate, expected to lie in [0, 1]
     * @return float Always a finite value
     */
    protected function calculateAlpha(float $errorRate)
    {
        // The formula is undefined at both ends of the range: a rate of 0 divides
        // by zero and a rate of 1 yields log(0) = -INF. Both ends are moved inside
        // the range by the same small margin so that alpha stays finite.
        $epsilon = 1e-10;
        if ($errorRate <= 0) {
            $errorRate = $epsilon;
        } elseif ($errorRate >= 1) {
            $errorRate = 1 - $epsilon;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }
232
233
    /**
     * Updates the sample weights after a boosting round.
     *
     * Every weight is multiplied by exp(-alpha * target * prediction), which
     * raises the weight of misclassified samples when alpha is positive, and
     * the result is then normalised so that the new weights sum to 1.
     *
     * @param Classifier $classifier Classifier trained in the current round
     * @param float $alpha Voting weight of that classifier
     */
    protected function updateWeights(Classifier $classifier, float $alpha)
    {
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            $weightsT1[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalise by the sum of the UPDATED weights; dividing by the sum of the
        // previous weights would leave the new weights not summing to 1.
        $normalizer = array_sum($weightsT1);
        if ($normalizer > 0 && is_finite($normalizer)) {
            foreach ($weightsT1 as $index => $weight) {
                $weightsT1[$index] = $weight / $normalizer;
            }
        }
        // Otherwise (all weights vanished or overflowed) the weights are left
        // unnormalised rather than dividing by zero or by infinity.

        $this->weights = $weightsT1;
    }
254
255
    /**
     * Predicts the label of a single sample by weighted vote.
     *
     * Every stored classifier votes with its prediction multiplied by its
     * alpha. A strictly positive total selects the label encoded as 1, any
     * other total selects the label encoded as -1.
     *
     * @param array $sample
     * @return mixed One of the two labels seen during training
     */
    public function predictSample(array $sample)
    {
        $votes = 0;
        foreach (array_keys($this->alpha) as $i) {
            $votes += $this->alpha[$i] * $this->classifiers[$i]->predict($sample);
        }

        if ($votes > 0) {
            return $this->labels[1];
        }

        return $this->labels[-1];
    }
269
}
270