Completed
Push — master ( cf222b...4daa0a )
by Arkadiusz
03:24
created

AdaBoost::updateWeights()   A

Complexity

Conditions 2
Paths 2

Size

Total Lines 15
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 15
rs 9.4285
c 0
b 0
f 0
cc 2
eloc 9
nc 2
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Linear\DecisionStump;
8
use Phpml\Classification\Classifier;
9
use Phpml\Helper\Predictable;
10
use Phpml\Helper\Trainable;
11
12
class AdaBoost implements Classifier
13
{
14
    use Predictable, Trainable;
15
16
    /**
17
     * Actual labels given in the targets array
18
     * @var array
19
     */
20
    protected $labels = [];
21
22
    /**
23
     * @var int
24
     */
25
    protected $sampleCount;
26
27
    /**
28
     * @var int
29
     */
30
    protected $featureCount;
31
32
    /**
33
     * Maximum number of iterations to be done
34
     *
35
     * @var int
36
     */
37
    protected $maxIterations;
38
39
    /**
40
     * Sample weights
41
     *
42
     * @var array
43
     */
44
    protected $weights = [];
45
46
    /**
47
     * Base classifiers
48
     *
49
     * @var array
50
     */
51
    protected $classifiers = [];
52
53
    /**
54
     * Base classifier weights
55
     *
56
     * @var array
57
     */
58
    protected $alpha = [];
59
60
    /**
     * Creates an AdaBoost (ADAptive BOOSTing) ensemble classifier.
     *
     * AdaBoost combines several 'weak' classifiers into a stronger one;
     * this implementation builds its ensemble out of DecisionStump instances.
     *
     * @param int $maxIterations Upper bound on the number of boosting rounds
     *                           executed by train()
     */
    public function __construct(int $maxIterations = 30)
    {
        $this->maxIterations = $maxIterations;
    }
70
71
    /**
     * Trains the ensemble on the given samples and their two-class targets.
     *
     * The two distinct target values are mapped to 1 and -1, every sample
     * starts with the same weight, and then up to $maxIterations boosting
     * rounds are executed. Each round picks the best weak classifier for the
     * current sample weights, derives its alpha from its error rate and
     * re-weights the samples accordingly.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when $targets does not contain exactly two distinct values
     */
    public function train(array $samples, array $targets)
    {
        // Initialize usual variables
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) !== 2) {
            throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
        }

        // Set all target values to either -1 or 1
        $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
        foreach ($targets as $target) {
            // Loose comparison is intentional: the labels were obtained through
            // array keys, so a numeric-string target may be stored as an integer.
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($samples[0]);
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        for ($round = 0; $round < $this->maxIterations; $round++) {
            // Determine the best 'weak' classifier based on current weights
            // and update alpha & weight values at each iteration
            list($classifier, $errorRate) = $this->getBestClassifier();

            // getBestClassifier() yields no classifier when no stump beats its
            // initial error bound. The weights would stay untouched, so every
            // further round would give the same result: stop boosting here
            // instead of passing null to updateWeights().
            if (!$classifier instanceof DecisionStump) {
                break;
            }

            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }
111
112
    /**
     * Finds the weak classifier with the lowest training error rate under
     * the current sample weights.
     *
     * One DecisionStump is created per value in 0 .. featureCount - 1
     * (presumably the feature column the stump works on), given the current
     * sample weights and trained on the stored samples and targets.
     *
     * Only DecisionStump is supported as the base classifier for now;
     * generalizing this to other classifiers is left as a future task.
     *
     * @return array Two-element array: the best stump (null when no stump has
     *               an error rate below 1.0) followed by its error rate
     */
    protected function getBestClassifier()
    {
        $winner = null;
        $lowestError = 1.0;

        $column = 0;
        while ($column < $this->featureCount) {
            $candidate = new DecisionStump($column);
            $candidate->setSampleWeights($this->weights);
            $candidate->train($this->samples, $this->targets);

            $candidateError = $candidate->getTrainingErrorRate();
            if ($candidateError < $lowestError) {
                $lowestError = $candidateError;
                $winner = $candidate;
            }

            $column++;
        }

        return [$winner, $lowestError];
    }
139
140
    /**
     * Calculates alpha (the voting weight) of a classifier from its error rate
     * as 0.5 * ln((1 - errorRate) / errorRate).
     *
     * The formula is undefined at both ends of the range: an error rate of 0
     * divides by zero and an error rate of 1 takes the logarithm of zero.
     * Rates at or beyond either end are therefore moved just inside the open
     * interval (0, 1), which yields a large but finite alpha of matching sign.
     * Rates strictly between 0 and 1 are used unchanged.
     *
     * @param float $errorRate
     * @return float
     */
    protected function calculateAlpha(float $errorRate)
    {
        $epsilon = 1e-10;

        if ($errorRate <= 0) {
            $errorRate = $epsilon;
        } elseif ($errorRate >= 1) {
            $errorRate = 1 - $epsilon;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }
153
154
    /**
     * Updates the sample weights after a boosting round.
     *
     * Every weight is multiplied by exp(-alpha * desired * output), where
     * "desired" is the stored target of the sample and "output" is what the
     * given classifier predicts for it. The results are then divided by their
     * own sum so that the new weights add up to 1.
     *
     * @param DecisionStump $classifier
     * @param float $alpha
     */
    protected function updateWeights(DecisionStump $classifier, float $alpha)
    {
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            $weightsT1[] = $weight * exp(-$alpha * $desired * $output);
        }

        // Normalize with the sum of the *updated* weights. Dividing by the sum
        // of the previous weights would leave the new weights unnormalized.
        $sumOfWeights = array_sum($weightsT1);
        if ($sumOfWeights > 0 && is_finite($sumOfWeights)) {
            foreach ($weightsT1 as $index => $weight) {
                $weightsT1[$index] = $weight / $sumOfWeights;
            }
        }
        // Otherwise (empty, all-zero or overflowed weights) the raw values are
        // kept, because dividing by such a sum would only produce NaN/INF.

        $this->weights = $weightsT1;
    }
175
176
    /**
     * Predicts the label of a single sample by a weighted vote of the
     * trained classifiers.
     *
     * Each classifier's prediction is multiplied by its alpha and the products
     * are summed. A sum greater than zero selects the label stored under key 1,
     * anything else selects the label stored under key -1.
     *
     * @param array $sample
     * @return mixed
     */
    public function predictSample(array $sample)
    {
        $votes = 0;
        foreach (array_keys($this->alpha) as $position) {
            $votes += $this->classifiers[$position]->predict($sample) * $this->alpha[$position];
        }

        $sign = -1;
        if ($votes > 0) {
            $sign = 1;
        }

        return $this->labels[$sign];
    }
190
}
191