DecisionStump::__construct()   A
last analyzed

Complexity

Conditions 1
Paths 1

Size

Total Lines 3
Code Lines 1

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Classification\DecisionTree;
8
use Phpml\Classification\WeightedClassifier;
9
use Phpml\Exception\InvalidArgumentException;
10
use Phpml\Helper\OneVsRest;
11
use Phpml\Helper\Predictable;
12
use Phpml\Math\Comparison;
13
14
/**
 * One-level decision tree (decision stump).
 *
 * Each binary sub-problem handed over by the OneVsRest trait is solved by a
 * single comparison "sample[column] <operator> value": when it holds the first
 * binary label is predicted, otherwise the second one.
 */
class DecisionStump extends WeightedClassifier
{
    use Predictable;
    use OneVsRest;

    public const AUTO_SELECT = -1;

    /**
     * Column index requested by the caller, or AUTO_SELECT to let the stump choose
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * The two labels of the current binary problem: index 0 is predicted when the
     * decision condition holds, index 1 otherwise
     *
     * @var array
     */
    protected $binaryLabels = [];

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column the decision node is built on
     *
     * @var int
     */
    protected $column;

    /**
     * Threshold (continuous columns) or category value (nominal columns) that
     * samples are compared against
     *
     * @var mixed
     */
    protected $value;

    /**
     * Operator passed to Comparison::compare() for the decision
     *
     * @var string
     */
    protected $operator;

    /**
     * Column types as returned by DecisionTree::getColumnTypes()
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * @var int
     */
    protected $featureCount;

    /**
     * Number of equal parts the range of a continuous column is divided into
     *
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves
     *
     * @var array
     */
    protected $prob = [];

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms in the weak classifier role. <br>
     *
     * If a column index is given, the stump builds its decision node on that
     * column; when the value is -1 (AUTO_SELECT) or the index does not exist in
     * the training data, the stump itself decides which column to use
     * (default DecisionTree behaviour).
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Human-readable form of the learned rule.
     */
    public function __toString(): string
    {
        // Curly interpolation reads the properties; "${this}->column" would
        // interpolate $this itself and recurse back into __toString().
        return "IF {$this->column} {$this->operator} {$this->value} ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }

    /**
     * While finding the best split point for a numerical valued column,
     * DecisionStump probes equally spaced values between the minimum and maximum
     * values in the column. The given <i>$count</i> determines how many split
     * points are probed: more split points give a finer search at the cost of
     * processing time (default value is 100.0)
     *
     * @throws InvalidArgumentException if $count is not greater than zero
     */
    public function setNumericalSplitCount(float $count): void
    {
        // Zero would divide by zero and a negative count would make the split
        // search step backwards forever.
        if ($count <= 0) {
            throw new InvalidArgumentException('Numerical split count must be greater than zero');
        }

        $this->numSplitCount = $count;
    }

    /**
     * Finds the column/operator/value combination with the lowest weighted
     * error rate for the given binary problem and stores it on the stump.
     *
     * @param array $labels the two labels of this binary problem
     *
     * @throws InvalidArgumentException if the number of sample weights differs from the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels): void
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // A given column index must refer to an existing column; anything else
        // (including negative indices) falls back to automatic selection.
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex >= $this->featureCount) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample.
        // NOTE(review): auto-filled weights stay on the instance, so a later
        // training run on a dataset of a different size would hit the mismatch
        // exception below unless something resets $this->weights — confirm.
        if (count($this->weights) === 0) {
            $this->weights = array_fill(0, count($samples), 1);
        } elseif (count($this->weights) !== count($samples)) {
            throw new InvalidArgumentException('Number of sample weights does not match with number of samples');
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the candidate columns of the dataset
        // by calculating the error rate for each split point in each column
        $columns = $this->givenColumnIndex === self::AUTO_SELECT
            ? range(0, $this->featureCount - 1)
            : [$this->givenColumnIndex];

        $bestSplit = $this->buildSplit(0, '', [], 0, 1.0);
        foreach ($columns as $col) {
            $split = $this->columnTypes[$col] == DecisionTree::CONTINUOUS
                ? $this->getBestNumericalSplit($samples, $targets, $col)
                : $this->getBestNominalSplit($samples, $targets, $col);

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        $this->value = $bestSplit['value'];
        $this->operator = $bestSplit['operator'];
        $this->prob = $bestSplit['prob'];
        $this->column = $bestSplit['column'];
        $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
    }

    /**
     * Determines the best split point for the given continuous column.
     *
     * @return array split description with keys value, operator, prob, column, trainingErrorRate
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);
        $stepSize = ($maxValue - $minValue) / $this->numSplitCount;

        // Before trying the evenly spaced points, the column mean is tried as the
        // cut point; it does not depend on the operator, so compute it once.
        $mean = array_sum($values) / (float) count($values);

        $split = [];

        foreach (['<=', '>'] as $operator) {
            [$errorRate, $prob] = $this->calculateErrorRate($targets, $mean, $operator, $values);
            if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                $split = $this->buildSplit($mean, $operator, $prob, $col, $errorRate);
            }

            // A constant column (zero range) or a non-positive split count gives
            // a step that never advances the loop below.
            if ($stepSize <= 0) {
                continue;
            }

            // Try other possible points one by one
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $threshold = (float) $step;
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);
                if ($errorRate < $split['trainingErrorRate']) {
                    $split = $this->buildSplit($threshold, $operator, $prob, $col, $errorRate);
                }
            }
        }

        return $split;
    }

    /**
     * Determines the best split for the given nominal column by trying every
     * distinct value of the column with both the '=' and '!=' operators.
     *
     * @return array split description with keys value, operator, prob, column, trainingErrorRate
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        $distinctVals = array_keys(array_count_values($values));

        $split = [];

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);
                // Keep the split with the LOWEST error rate
                if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                    $split = $this->buildSplit($val, $operator, $prob, $col, $errorRate);
                }
            }
        }

        return $split;
    }

    /**
     * Calculates the weighted ratio of wrong predictions based on the new threshold
     * value given as the parameter, together with the label distribution in the leaves.
     *
     * @param mixed $threshold numeric threshold for continuous columns, or a category value for nominal ones
     *
     * @return array [weighted error rate (float), leaf proportions keyed by binary label (array)]
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values): array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            $predicted = Comparison::compare($value, $threshold, $operator) ? $leftLabel : $rightLabel;
            $target = $targets[$index];

            if ((string) $predicted != (string) $target) {
                $wrong += $this->weights[$index];
            }

            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }

            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node; it is
     * 0.0 when the stump does not predict $label for $sample.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            return $this->prob[$label];
        }

        return 0.0;
    }

    /**
     * Applies the learned rule: the first binary label when the comparison
     * holds for the sample, the second one otherwise.
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * Intentionally empty: trainBinary() reassigns the split state on every call.
     */
    protected function resetBinary(): void
    {
    }

    /**
     * Packs one candidate split into the array shape shared by the split finders.
     *
     * @param mixed $value threshold or category value
     */
    private function buildSplit($value, string $operator, array $prob, int $column, float $errorRate): array
    {
        return [
            'value' => $value,
            'operator' => $operator,
            'prob' => $prob,
            'column' => $column,
            'trainingErrorRate' => $errorRate,
        ];
    }
}
320