Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Classification/Linear/DecisionStump.php (4 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Exception;
8
use Phpml\Classification\DecisionTree;
9
use Phpml\Classification\WeightedClassifier;
10
use Phpml\Helper\OneVsRest;
11
use Phpml\Helper\Predictable;
12
use Phpml\Math\Comparison;
13
14
class DecisionStump extends WeightedClassifier
{
    use Predictable, OneVsRest;

    public const AUTO_SELECT = -1;

    /**
     * Column index requested through the constructor, or AUTO_SELECT
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * Labels of the binary problem being trained: the first one is predicted
     * when the split condition holds, the second one otherwise
     *
     * @var array
     */
    protected $binaryLabels = [];

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column the decision is made on
     *
     * @var int
     */
    protected $column;

    /**
     * Cut point (continuous column) or category (nominal column) of the split
     *
     * @var mixed
     */
    protected $value;

    /**
     * Split operator: '<=' or '>' for continuous columns, '=' or '!=' for nominal ones
     *
     * @var string
     */
    protected $operator;

    /**
     * Type of each column as reported by DecisionTree::getColumnTypes()
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Number of columns seen in the training samples
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of equal-width intervals probed on continuous columns
     *
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves: for each binary label, the share
     * of the samples sent to that label's leaf whose target is that label
     *
     * @var array
     */
    protected $prob = [];

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms as in the weak classifier role. <br>
     *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Human-readable form of the learned rule:
     * "IF <column> <operator> <value> THEN <label> ELSE <label>"
     */
    public function __toString(): string
    {
        // The braced "{$this->prop}" syntax is required: the former "${this}->prop"
        // form interpolated $this itself (followed by the literal text "->prop"),
        // which re-entered __toString() and recursed without end.
        return "IF {$this->column} {$this->operator} {$this->value} ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }

    /**
     * While finding best split point for a numerical valued column,
     * DecisionStump looks for equally distanced values between minimum and maximum
     * values in the column. Given <i>$count</i> value determines how many split
     * points to be probed. The more split counts, the better performance but
     * worse processing time (Default value is 100.0)
     *
     * @throws \InvalidArgumentException when $count is not strictly positive
     */
    public function setNumericalSplitCount(float $count): void
    {
        // The probing step is (max - min) / count: a zero count divides by zero
        // and a negative one yields a negative step that never reaches the maximum.
        // Written as !(> 0) so that NAN is rejected as well.
        if (!($count > 0.0)) {
            throw new \InvalidArgumentException('Numerical split count must be greater than zero');
        }

        $this->numSplitCount = $count;
    }

    /**
     * Trains the stump on a two-label problem by keeping the single
     * column/value/operator split with the lowest weighted error rate.
     *
     * @throws \Exception when sample weights were set but their count differs
     *                    from the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels): void
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // A given column index is only honoured when it refers to an existing
        // column; anything else (including negative values other than AUTO_SELECT)
        // falls back to searching every column. A local is used so the configured
        // index is not lost for later trainings on wider datasets.
        $columnIndex = $this->givenColumnIndex;
        if ($columnIndex < 0 || $columnIndex >= $this->featureCount) {
            $columnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample
        if ($this->weights) {
            if (count($this->weights) !== count($samples)) {
                throw new Exception('Number of sample weights does not match with number of samples');
            }
        } else {
            $this->weights = array_fill(0, count($samples), 1);
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = $columnIndex === self::AUTO_SELECT
            ? range(0, $this->featureCount - 1)
            : [$columnIndex];

        $bestSplit = [
            'value' => 0,
            'operator' => '',
            'prob' => [],
            'column' => 0,
            'trainingErrorRate' => 1.0,
        ];
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
                $split = $this->getBestNumericalSplit($samples, $targets, $col);
            } else {
                $split = $this->getBestNominalSplit($samples, $targets, $col);
            }

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        $this->value = $bestSplit['value'];
        $this->operator = $bestSplit['operator'];
        $this->prob = $bestSplit['prob'];
        $this->column = $bestSplit['column'];
        $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
    }

    /**
     * Determines best split point for the given continuous column
     *
     * Both '<=' and '>' are tried with the column mean as the cut point and with
     * points spaced (max - min) / numSplitCount apart from min up to max.
     *
     * @return array split with the keys value, operator, prob, column and trainingErrorRate
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);
        $stepSize = ($maxValue - $minValue) / $this->numSplitCount;
        $mean = array_sum($values) / (float) count($values);

        $split = [];

        foreach (['<=', '>'] as $operator) {
            // Before trying all possible split points, let's first try
            // the average value for the cut point
            [$errorRate, $prob] = $this->calculateErrorRate($targets, $mean, $operator, $values);
            if ($split === [] || $errorRate < $split['trainingErrorRate']) {
                $split = [
                    'value' => $mean,
                    'operator' => $operator,
                    'prob' => $prob,
                    'column' => $col,
                    'trainingErrorRate' => $errorRate,
                ];
            }

            // A constant column gives a zero step, which would never advance the
            // loop below. Such a column only has trivial cut points (all samples on
            // one side), and the mean probe above already covers those for both operators.
            if ($stepSize <= 0) {
                continue;
            }

            // Try other possible points one by one
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $threshold = (float) $step;
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);
                if ($errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $threshold,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Determines best split for the given nominal column by trying every
     * distinct value of the column with both the '=' and '!=' operators
     *
     * @return array split with the keys value, operator, prob, column and trainingErrorRate
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        $distinctVals = array_keys(array_count_values($values));

        $split = [];

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);

                // Keep the split with the LOWEST error rate; the comparison was
                // previously inverted and kept the worst candidate instead.
                if ($split === [] || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $val,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Calculates the ratio of wrong predictions based on the new threshold
     * value given as the parameter
     *
     * Samples whose value satisfies "value <operator> threshold" are predicted
     * as the first binary label, all others as the second one.
     *
     * @param mixed $threshold numeric cut point for continuous columns, or a
     *                         category (int or string) for nominal columns.
     *                         Left untyped because under strict_types a float
     *                         parameter rejected string categories with a TypeError.
     *
     * @return array [weighted error rate, leaf distribution keyed by binary label]
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values): array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            $predicted = Comparison::compare($value, $threshold, $operator) ? $leftLabel : $rightLabel;

            $target = $targets[$index];
            if ((string) $predicted != (string) $target) {
                $wrong += $this->weights[$index];
            }

            // Unweighted count of targets reaching each predicted leaf
            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }

            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            return $this->prob[$label];
        }

        return 0.0;
    }

    /**
     * Applies the learned rule to a single sample
     *
     * @return mixed the first binary label when the condition holds, the second otherwise
     */
    protected function predictSampleBinary(array $sample)
    {
        if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * Intentionally empty: trainBinary() overwrites the split state
     * (value, operator, prob, column, trainingErrorRate) on every call.
     */
    protected function resetBinary(): void
    {
    }
}
320