Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Classification/Linear/DecisionStump.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis, consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Exception;
8
use Phpml\Classification\DecisionTree;
9
use Phpml\Classification\WeightedClassifier;
10
use Phpml\Helper\OneVsRest;
11
use Phpml\Helper\Predictable;
12
use Phpml\Math\Comparison;
13
14
/**
 * A one-level decision tree ("stump") used as a weak binary classifier.
 *
 * Training probes candidate split points on one or all columns and keeps the
 * split with the lowest weighted training error rate.
 */
class DecisionStump extends WeightedClassifier
{
    use Predictable, OneVsRest;

    /**
     * Column index meaning "let the stump choose the column itself".
     */
    public const AUTO_SELECT = -1;

    /**
     * Column index requested through the constructor, or AUTO_SELECT.
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * The two labels of the binary problem: index 0 is predicted when the
     * split condition holds, index 1 otherwise.
     *
     * @var array
     */
    protected $binaryLabels = [];

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column used by the selected split.
     *
     * @var int
     */
    protected $column;

    /**
     * Threshold (continuous column) or category (nominal column) of the
     * selected split.
     *
     * @var mixed
     */
    protected $value;

    /**
     * Comparison operator of the selected split: '<=', '>', '=' or '!='.
     *
     * @var string
     */
    protected $operator;

    /**
     * Per-column types as returned by DecisionTree::getColumnTypes().
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Number of features, taken from the first training sample.
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of equally spaced intervals probed on a continuous column.
     *
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves
     *
     * @var array
     */
    protected $prob = [];

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms as in the weak classifier role. <br>
     *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Renders the learned rule as "IF <column> <operator> <value> THEN <label0> ELSE <label1>".
     */
    public function __toString(): string
    {
        // Use {$this->prop}: "${this}->prop" would interpolate $this itself
        // and re-enter __toString() recursively.
        return "IF {$this->column} {$this->operator} {$this->value} ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }

    /**
     * While finding best split point for a numerical valued column,
     * DecisionStump looks for equally distanced values between minimum and maximum
     * values in the column. Given <i>$count</i> value determines how many split
     * points to be probed. The more split counts, the better performance but
     * worse processing time (Default value is 100.0)
     */
    public function setNumericalSplitCount(float $count): void
    {
        $this->numSplitCount = $count;
    }

    /**
     * Trains the stump on a binary problem and stores the best split found.
     *
     * @param array $samples Training samples, one array of feature values per sample
     * @param array $targets Target label of each sample, keyed like $samples
     * @param array $labels  The two labels of the binary problem
     *
     * @throws \Exception If no samples are given or the number of weights
     *                    differs from the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels): void
    {
        if (empty($samples)) {
            throw new Exception('Cannot train a DecisionStump without samples');
        }

        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // If a column index is given, it should be among the existing columns;
        // anything outside 0..featureCount-1 falls back to automatic selection
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex >= $this->featureCount) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample.
        // The emptiness check is explicit rather than an implicit array-to-bool cast.
        if (!empty($this->weights)) {
            if (count($this->weights) !== count($samples)) {
                throw new Exception('Number of sample weights does not match with number of samples');
            }
        } else {
            $this->weights = array_fill(0, count($samples), 1);
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = $this->givenColumnIndex === self::AUTO_SELECT
            ? range(0, $this->featureCount - 1)
            : [$this->givenColumnIndex];

        // Fallback used when no candidate beats an error rate of 1.0
        $bestSplit = [
            'value' => 0,
            'operator' => '',
            'prob' => [],
            'column' => 0,
            'trainingErrorRate' => 1.0,
        ];
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
                $split = $this->getBestNumericalSplit($samples, $targets, $col);
            } else {
                $split = $this->getBestNominalSplit($samples, $targets, $col);
            }

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump; the keys of the split
        // array match the property names (value, operator, prob, column,
        // trainingErrorRate)
        foreach ($bestSplit as $name => $value) {
            $this->{$name} = $value;
        }
    }

    /**
     * Determines best split point for the given continuous column.
     *
     * @return array Split description with the keys value, operator, prob,
     *               column and trainingErrorRate
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);

        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $split = [];

        foreach (['<=', '>'] as $operator) {
            foreach ($this->numericalThresholds($values) as $threshold) {
                $candidate = $this->evaluateSplit($targets, $values, $threshold, $operator, $col);
                if ($split === [] || $candidate['trainingErrorRate'] < $split['trainingErrorRate']) {
                    $split = $candidate;
                }
            }
        }

        return $split;
    }

    /**
     * Determines best split for the given nominal column by testing every
     * distinct value with both the '=' and '!=' operators.
     *
     * @return array Split description with the keys value, operator, prob,
     *               column and trainingErrorRate
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        $distinctVals = array_keys(array_count_values($values));

        $split = [];

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                $candidate = $this->evaluateSplit($targets, $values, $val, $operator, $col);

                // Keep the candidate with the LOWEST error rate, exactly as
                // the numerical search does
                if ($split === [] || $candidate['trainingErrorRate'] < $split['trainingErrorRate']) {
                    $split = $candidate;
                }
            }
        }

        return $split;
    }

    /**
     * Calculates the ratio of wrong predictions based on the new threshold
     * value given as the parameter
     *
     * @param array  $targets   Target label of each sample, keyed like $values
     * @param mixed  $threshold Numeric threshold or nominal category to compare against;
     *                          left untyped because nominal categories may be strings
     * @param string $operator  Comparison operator handed to Comparison::compare()
     * @param array  $values    Values of the inspected column, one per sample
     *
     * @return array Two elements: the weighted error rate, and a map from each
     *               binary label to the proportion of samples in that label's
     *               leaf whose target equals the label
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values): array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            if (Comparison::compare($value, $threshold, $operator)) {
                $predicted = $leftLabel;
            } else {
                $predicted = $rightLabel;
            }

            $target = $targets[$index];
            if ((string) $predicted != (string) $target) {
                // Misclassified samples contribute their weight to the error
                $wrong += $this->weights[$index];
            }

            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }

            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            // Fall back to 0.0 when no distribution was stored for the label
            return (float) ($this->prob[$label] ?? 0.0);
        }

        return 0.0;
    }

    /**
     * Applies the learned split to a single sample.
     *
     * @return mixed The first binary label when the split condition holds,
     *               the second one otherwise
     */
    protected function predictSampleBinary(array $sample)
    {
        if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * Intentionally a no-op.
     */
    protected function resetBinary(): void
    {
    }

    /**
     * Evaluates one candidate split and returns its description.
     *
     * @param mixed $threshold Numeric threshold or nominal category
     *
     * @return array Split description with the keys value, operator, prob,
     *               column and trainingErrorRate
     */
    private function evaluateSplit(array $targets, array $values, $threshold, string $operator, int $col): array
    {
        [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);

        return [
            'value' => $threshold,
            'operator' => $operator,
            'prob' => $prob,
            'column' => $col,
            'trainingErrorRate' => $errorRate,
        ];
    }

    /**
     * Yields the candidate thresholds for a continuous column: the mean value
     * first, then equally spaced points from the minimum up to the maximum.
     */
    private function numericalThresholds(array $values): iterable
    {
        // Before trying all possible split points, first try
        // the average value for the cut point
        yield array_sum($values) / (float) count($values);

        $minValue = min($values);
        $maxValue = max($values);

        // A non-positive split count cannot produce a usable step
        $stepSize = $this->numSplitCount > 0
            ? ($maxValue - $minValue) / $this->numSplitCount
            : 0.0;

        if ($stepSize > 0) {
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                yield (float) $step;
            }

            return;
        }

        // Constant column (or unusable step): adding a zero step would never
        // terminate, so the single distinct point is probed once instead
        yield (float) $minValue;
    }
}
320