DecisionStump   A
last analyzed

Complexity

Total Complexity 36

Size/Duplication

Total Lines 306
Duplicated Lines 0 %

Coupling/Cohesion

Components 1
Dependencies 6

Importance

Changes 0
Metric Value
wmc 36
lcom 1
cbo 6
dl 0
loc 306
rs 9.52
c 0
b 0
f 0

10 Methods

Rating   Name   Duplication   Size   Complexity  
B calculateErrorRate() 0 39 8
A predictProbability() 0 9 2
A predictSampleBinary() 0 8 2
A resetBinary() 0 3 1
A __construct() 0 4 1
A __toString() 0 6 1
A setNumericalSplitCount() 0 4 1
B trainBinary() 0 55 9
B getBestNumericalSplit() 0 46 6
A getBestNominalSplit() 0 25 5
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Classification\DecisionTree;
8
use Phpml\Classification\WeightedClassifier;
9
use Phpml\Exception\InvalidArgumentException;
10
use Phpml\Helper\OneVsRest;
11
use Phpml\Helper\Predictable;
12
use Phpml\Math\Comparison;
13
14
/**
 * One-level decision tree ("stump") used as a binary weak classifier.
 *
 * Training probes candidate split points on one or all feature columns and
 * keeps the column / operator / value triple with the lowest weighted
 * error rate.
 */
class DecisionStump extends WeightedClassifier
{
    use Predictable;
    use OneVsRest;

    /**
     * Sentinel column index meaning "let the stump pick the best column".
     */
    public const AUTO_SELECT = -1;

    /**
     * Column index requested by the caller, or self::AUTO_SELECT
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * The two labels of the current binary problem: index 0 is predicted when
     * the split condition holds, index 1 otherwise
     *
     * @var array
     */
    protected $binaryLabels = [];

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column the trained split operates on
     *
     * @var int
     */
    protected $column;

    /**
     * Split value the sample's column value is compared against
     *
     * @var mixed
     */
    protected $value;

    /**
     * Comparison operator of the trained split ('<=', '>', '=' or '!=')
     *
     * @var string
     */
    protected $operator;

    /**
     * Column types as returned by DecisionTree::getColumnTypes()
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Number of features (columns) in the training samples
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of equally spaced split points probed on a numerical column
     *
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves
     *
     * @var array
     */
    protected $prob = [];

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms as in the weak classifier role. <br>
     *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Renders the trained rule as "IF <column> <operator> <value> THEN <label0> ELSE <label1>".
     */
    public function __toString(): string
    {
        // Use {$this->prop}: "${this}->prop" would interpolate $this itself and
        // re-enter __toString() endlessly.
        return "IF {$this->column} {$this->operator} {$this->value} ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }

    /**
     * While finding best split point for a numerical valued column,
     * DecisionStump looks for equally distanced values between minimum and maximum
     * values in the column. Given <i>$count</i> value determines how many split
     * points to be probed. The more split counts, the better performance but
     * worse processing time (Default value is 100.0)
     *
     * @throws InvalidArgumentException when $count is not greater than zero
     */
    public function setNumericalSplitCount(float $count): void
    {
        // The count is used as a divisor for the probing step size, so zero or
        // negative values can never yield a usable step.
        if ($count <= 0.0) {
            throw new InvalidArgumentException('Numerical split count must be greater than zero');
        }

        $this->numSplitCount = $count;
    }

    /**
     * Trains the stump on a two-label problem and stores the best split found.
     *
     * @param array $samples rows of feature values
     * @param array $targets target label per sample, keyed like $samples
     * @param array $labels  the two labels; index 0 is assigned when the split condition holds
     *
     * @throws InvalidArgumentException when preset weights do not match the sample count
     */
    protected function trainBinary(array $samples, array $targets, array $labels): void
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // If a column index is given, it should be among the existing columns;
        // anything outside [0, featureCount - 1] falls back to auto selection
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex > $this->featureCount - 1) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample
        if (count($this->weights) === 0) {
            $this->weights = array_fill(0, count($samples), 1);
        } elseif (count($this->weights) !== count($samples)) {
            throw new InvalidArgumentException('Number of sample weights does not match with number of samples');
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = $this->givenColumnIndex === self::AUTO_SELECT
            ? range(0, $this->featureCount - 1)
            : [$this->givenColumnIndex];

        // Fallback used when no candidate achieves an error rate below 1.0
        $bestSplit = [
            'value' => 0,
            'operator' => '',
            'prob' => [],
            'column' => 0,
            'trainingErrorRate' => 1.0,
        ];
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
                $split = $this->getBestNumericalSplit($samples, $targets, $col);
            } else {
                $split = $this->getBestNominalSplit($samples, $targets, $col);
            }

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        $this->value = $bestSplit['value'];
        $this->operator = $bestSplit['operator'];
        $this->prob = $bestSplit['prob'];
        $this->column = $bestSplit['column'];
        $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
    }

    /**
     * Determines best split point for the given numerical column
     *
     * @return array split description with the keys 'value', 'operator', 'prob',
     *               'column' and 'trainingErrorRate'
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);
        $stepSize = $this->numSplitCount > 0
            ? ($maxValue - $minValue) / $this->numSplitCount
            : 0.0;

        // Candidate cut points, identical for both operators:
        // first the average value, then the equally spaced points
        $thresholds = [array_sum($values) / (float) count($values)];
        if ($stepSize > 0) {
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $thresholds[] = (float) $step;
            }
        } else {
            // All values are equal (or no usable step): stepping by a
            // non-positive amount would never pass $maxValue, so probe the
            // single distinct value once
            $thresholds[] = (float) $minValue;
        }

        $split = [];

        foreach (['<=', '>'] as $operator) {
            foreach ($thresholds as $threshold) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);
                if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $threshold,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Determines best split for the given nominal column by testing every
     * distinct value with the '=' and '!=' operators
     *
     * @return array split description with the keys 'value', 'operator', 'prob',
     *               'column' and 'trainingErrorRate'
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        $valueCounts = array_count_values($values);
        $distinctVals = array_keys($valueCounts);

        $split = [];

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);
                // Keep the candidate with the LOWEST error rate, consistent
                // with getBestNumericalSplit() and trainBinary()
                if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $val,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Calculates the ratio of wrong predictions based on the new threshold
     * value given as the parameter
     *
     * @param array  $targets   target label per sample, keyed like $values
     * @param mixed  $threshold split value; a float for numerical columns, an
     *                          int or string for nominal columns
     * @param string $operator  comparison operator handed to Comparison::compare()
     * @param array  $values    the column values of all samples
     *
     * @return array [weighted error rate, per-label share of correctly labelled samples in its leaf]
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values): array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            $predicted = Comparison::compare($value, $threshold, $operator) ? $leftLabel : $rightLabel;
            $target = $targets[$index];

            // Labels are compared via their string form (loose comparison kept on purpose)
            if ((string) $predicted != (string) $target) {
                $wrong += $this->weights[$index];
            }

            // Count how often each target label ends up in each predicted leaf
            $prob[$predicted][$target] = ($prob[$predicted][$target] ?? 0) + 1;
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            // $prob stays empty when training found no split with an error
            // rate below 1.0; report 0.0 in that case
            return $this->prob[$label] ?? 0.0;
        }

        return 0.0;
    }

    /**
     * Applies the trained split to a single sample.
     *
     * @return mixed the first binary label when the split condition holds, the second otherwise
     */
    protected function predictSampleBinary(array $sample)
    {
        if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * Intentionally empty: this classifier performs no work when reset.
     */
    protected function resetBinary(): void
    {
    }
}
320