Passed
Push — master ( 331d4b...653c7c )
by Arkadiusz
02:19
created

src/Phpml/Classification/Linear/DecisionStump.php (4 issues)

Labels
Severity

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Classification\DecisionTree;
8
use Phpml\Classification\WeightedClassifier;
9
use Phpml\Helper\OneVsRest;
10
use Phpml\Helper\Predictable;
11
use Phpml\Math\Comparison;
12
13
class DecisionStump extends WeightedClassifier
{
    use Predictable, OneVsRest;

    public const AUTO_SELECT = -1;

    /**
     * Column index requested through the constructor, or self::AUTO_SELECT
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * The two labels of the current binary problem: index 0 is predicted when
     * the split condition holds, index 1 otherwise
     *
     * @var array
     */
    protected $binaryLabels;

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column the selected split operates on
     *
     * @var int
     */
    protected $column;

    /**
     * Threshold (numerical column) or category value (nominal column) of the
     * selected split
     *
     * @var mixed
     */
    protected $value;

    /**
     * Comparison operator of the selected split ('<=', '>', '=' or '!=')
     *
     * @var string
     */
    protected $operator;

    /**
     * Column types as returned by DecisionTree::getColumnTypes()
     *
     * @var array
     */
    protected $columnTypes;

    /**
     * Number of columns in the training samples
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of equally sized intervals the value range of a numerical column
     * is divided into when probing split points
     *
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves
     *
     * @var array
     */
    protected $prob;

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms as in the weak classifier role. <br>
     *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Trains the stump on a binary problem: probes the candidate columns and
     * keeps the split with the lowest weighted error rate.
     *
     * @param array $samples Training samples, one row of column values per sample
     * @param array $targets Target label of each sample
     * @param array $labels  The two labels of the binary problem
     *
     * @throws \Exception when sample weights are set but their count differs
     *                    from the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels): void
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // A requested column must be among the existing columns; anything out
        // of range (too large or negative) falls back to automatic selection
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex > $this->featureCount - 1) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample
        $sampleCount = count($samples);
        if ($this->weights) {
            if (count($this->weights) !== $sampleCount) {
                throw new \Exception('Number of sample weights does not match with number of samples');
            }
        } else {
            $this->weights = array_fill(0, $sampleCount, 1);
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = $this->givenColumnIndex === self::AUTO_SELECT
            ? range(0, $this->featureCount - 1)
            : [$this->givenColumnIndex];

        $bestSplit = [
            'value' => 0,
            'operator' => '',
            'prob' => [],
            'column' => 0,
            'trainingErrorRate' => 1.0,
        ];

        foreach ($columns as $col) {
            $split = $this->columnTypes[$col] == DecisionTree::CONTINUOUS
                ? $this->getBestNumericalSplit($samples, $targets, $col)
                : $this->getBestNominalSplit($samples, $targets, $col);

            // Strict comparison: on ties the earliest column is kept
            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        $this->value = $bestSplit['value'];
        $this->operator = $bestSplit['operator'];
        $this->prob = $bestSplit['prob'];
        $this->column = $bestSplit['column'];
        $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
    }

    /**
     * While finding best split point for a numerical valued column,
     * DecisionStump looks for equally distanced values between minimum and maximum
     * values in the column. Given <i>$count</i> value determines how many split
     * points to be probed. The more split counts, the better performance but
     * worse processing time (Default value is 100.0)
     */
    public function setNumericalSplitCount(float $count): void
    {
        $this->numSplitCount = $count;
    }

    /**
     * Determines best split point for the given numerical column
     *
     * @return array Split description with the keys 'value', 'operator',
     *               'prob', 'column' and 'trainingErrorRate'
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);

        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);

        // The candidate thresholds do not depend on the operator, so they are
        // built once. The average value is always tried first.
        $thresholds = [array_sum($values) / (float) count($values)];

        $stepSize = $this->numSplitCount > 0
            ? ($maxValue - $minValue) / $this->numSplitCount
            : 0.0;

        if ($stepSize > 0) {
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $thresholds[] = (float) $step;
            }
        } else {
            // Constant column (or non-positive split count): a zero or negative
            // step would never leave the loop above, so probe the single
            // boundary value instead
            $thresholds[] = (float) $minValue;
        }

        $split = null;

        foreach (['<=', '>'] as $operator) {
            foreach ($thresholds as $threshold) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);

                if ($split === null || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $threshold,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Determines best split for the given nominal column by probing every
     * distinct value of the column with the '=' and '!=' operators
     *
     * @return array Split description with the keys 'value', 'operator',
     *               'prob', 'column' and 'trainingErrorRate'
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        $distinctVals = array_keys(array_count_values($values));

        $split = null;

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);

                // Keep the candidate with the LOWEST error rate
                if ($split === null || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $val,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Calculates the ratio of wrong predictions based on the new threshold
     * value given as the parameter
     *
     * @param array  $targets   Target label of each sample
     * @param mixed  $threshold Numerical threshold, or category value for nominal columns
     * @param string $operator  Operator handed to Comparison::compare()
     * @param array  $values    Values of the probed column, indexed like $targets
     *
     * @return array Two elements: the weighted error rate, and for each binary
     *               label the proportion of samples in that label's leaf whose
     *               target is that label
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values): array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            $predicted = Comparison::compare($value, $threshold, $operator) ? $leftLabel : $rightLabel;
            $target = $targets[$index];

            // Misclassified samples contribute their weight to the error
            if ((string) $predicted != (string) $target) {
                $wrong += $this->weights[$index];
            }

            // Count the targets that end up in each leaf
            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }

            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            return $this->prob[$label];
        }

        return 0.0;
    }

    /**
     * Applies the trained split to the sample and returns the first binary
     * label when the condition holds, the second one otherwise
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * Intentionally empty: this method performs no reset work
     */
    protected function resetBinary(): void
    {
    }

    /**
     * Renders the trained rule as "IF <column> <operator> <value> THEN <label> ELSE <label>"
     */
    public function __toString(): string
    {
        return "IF {$this->column} {$this->operator} {$this->value} ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }
}
303