Passed
Pull Request — master (#130)
by Marcin
07:14
created

DecisionStump::evaluate()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 5
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 5
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 3
nc 1
nop 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\OneVsRest;
9
use Phpml\Classification\WeightedClassifier;
10
use Phpml\Classification\DecisionTree;
11
use Phpml\Strategy\Compare\CompareStrategyFactory;
12
13
/**
 * One-level deep decision tree used mainly as a weak learner in ensembles.
 *
 * Training picks a single (column, operator, value) condition: samples for
 * which the condition holds are predicted as $binaryLabels[0], all others as
 * $binaryLabels[1].
 */
class DecisionStump extends WeightedClassifier
{
    use Predictable, OneVsRest;

    const AUTO_SELECT = -1;

    /**
     * Column index requested through the constructor, or AUTO_SELECT to let
     * the stump choose the column itself
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * The two labels of the current binary problem: index 0 is predicted when
     * the split condition holds, index 1 otherwise
     *
     * @var array
     */
    protected $binaryLabels;

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column the chosen split is evaluated on
     *
     * @var int
     */
    protected $column;

    /**
     * Threshold (continuous column) or category value (nominal column) of the split
     *
     * @var mixed
     */
    protected $value;

    /**
     * Comparison operator of the split: '<=' / '>' for continuous columns,
     * '=' / '!=' for nominal columns
     *
     * @var string
     */
    protected $operator;

    /**
     * Column types as returned by DecisionTree::getColumnTypes()
     *
     * @var array
     */
    protected $columnTypes;

    /**
     * Number of columns in the training samples
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of equally sized intervals probed on continuous columns
     *
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves
     *
     * @var array
     */
    protected $prob;

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms as in the weak classifier role. <br>
     *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     *
     * @param int $columnIndex
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Trains the stump on a binary sub-problem by searching every candidate
     * column for the split with the lowest weighted error rate.
     *
     * @param array $samples
     * @param array $targets
     * @param array $labels
     *
     * @throws \Exception when the number of sample weights differs from the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels)
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // A given column index must refer to an existing column; anything out of
        // range (including negative values other than AUTO_SELECT) falls back
        // to automatic column selection
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex >= $this->featureCount) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample.
        // !empty() makes the "no weights supplied" check explicit instead of
        // relying on implicit array-to-bool conversion.
        if (!empty($this->weights)) {
            if (count($this->weights) != count($samples)) {
                throw new \Exception('Number of sample weights does not match with number of samples');
            }
        } else {
            $this->weights = array_fill(0, count($samples), 1);
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = range(0, $this->featureCount - 1);
        if ($this->givenColumnIndex != self::AUTO_SELECT) {
            $columns = [$this->givenColumnIndex];
        }

        // Start with no split so the first evaluated candidate is always adopted,
        // even if its error rate is not strictly below 1.0
        $bestSplit = null;
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
                $split = $this->getBestNumericalSplit($samples, $targets, $col);
            } else {
                $split = $this->getBestNominalSplit($samples, $targets, $col);
            }

            if ($bestSplit === null || $split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        $this->value = $bestSplit['value'];
        $this->operator = $bestSplit['operator'];
        $this->prob = $bestSplit['prob'];
        $this->column = $bestSplit['column'];
        $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
    }

    /**
     * While finding best split point for a numerical valued column,
     * DecisionStump looks for equally distanced values between minimum and maximum
     * values in the column. Given <i>$count</i> value determines how many split
     * points to be probed. More split points give a more accurate split at the
     * cost of longer processing time (Default value is 100.0)
     *
     * @param float $count
     */
    public function setNumericalSplitCount(float $count)
    {
        $this->numSplitCount = $count;
    }

    /**
     * Determines best split point for the given column
     *
     * @param array $samples
     * @param array $targets
     * @param int   $col
     *
     * @return array split description (value, operator, prob, column, trainingErrorRate)
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col)
    {
        $values = array_column($samples, $col);
        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);
        // A non-positive split count would divide by zero or walk backwards
        $stepSize = $this->numSplitCount > 0
            ? ($maxValue - $minValue) / $this->numSplitCount
            : 0.0;

        // The average value is the same for both operators, so compute it once
        $average = array_sum($values) / (float) count($values);

        $split = null;

        foreach (['<=', '>'] as $operator) {
            // Before trying all possible split points, let's first try
            // the average value for the cut point
            list($errorRate, $prob) = $this->calculateErrorRate($targets, $average, $operator, $values);
            if ($split === null || $errorRate < $split['trainingErrorRate']) {
                $split = $this->makeSplit($average, $operator, $prob, $col, $errorRate);
            }

            // With a zero step (constant column) the loop below would never
            // advance past $maxValue, so only probe further points when the
            // step actually moves forward
            if ($stepSize <= 0) {
                continue;
            }

            // Try other possible points one by one
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $threshold = (float) $step;
                list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);
                if ($errorRate < $split['trainingErrorRate']) {
                    $split = $this->makeSplit($threshold, $operator, $prob, $col, $errorRate);
                }
            }
        }

        return $split;
    }

    /**
     * Determines the best '=' / '!=' split over the distinct values of a
     * nominal column
     *
     * @param array $samples
     * @param array $targets
     * @param int   $col
     *
     * @return array split description (value, operator, prob, column, trainingErrorRate)
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col) : array
    {
        $values = array_column($samples, $col);
        $distinctVals = array_keys(array_count_values($values));

        $split = null;

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                list($errorRate, $prob) = $this->calculateErrorRate($targets, $val, $operator, $values);

                // Keep the candidate with the LOWEST error rate, consistent with
                // the numerical split search and with trainBinary()
                if ($split === null || $errorRate < $split['trainingErrorRate']) {
                    $split = $this->makeSplit($val, $operator, $prob, $col, $errorRate);
                }
            }
        }

        return $split;
    }

    /**
     * Packs a candidate split into the array layout consumed by trainBinary()
     *
     * @param mixed  $value
     * @param string $operator
     * @param array  $prob
     * @param int    $col
     * @param float  $errorRate
     *
     * @return array
     */
    private function makeSplit($value, string $operator, array $prob, int $col, float $errorRate) : array
    {
        return [
            'value' => $value,
            'operator' => $operator,
            'prob' => $prob,
            'column' => $col,
            'trainingErrorRate' => $errorRate,
        ];
    }

    /**
     * Compares two values with the strategy registered for the operator
     *
     * @param mixed  $leftValue
     * @param string $operator
     * @param mixed  $rightValue
     *
     * @return boolean
     */
    protected function evaluate($leftValue, string $operator, $rightValue)
    {
        return CompareStrategyFactory::create($operator)
            ->compare($leftValue, $rightValue);
    }

    /**
     * Calculates the ratio of wrong predictions based on the new threshold
     * value given as the parameter
     *
     * @param array  $targets
     * @param mixed  $threshold numeric cut point for continuous columns, or a
     *                          category value for nominal columns
     * @param string $operator
     * @param array  $values
     *
     * @return array [weighted error rate, per-label leaf distribution]
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values) : array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            $predicted = $this->evaluate($value, $operator, $threshold) ? $leftLabel : $rightLabel;

            $target = $targets[$index];
            if ((string) $predicted != (string) $target) {
                $wrong += $this->weights[$index];
            }

            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }
            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node
     *
     * @param array $sample
     * @param mixed $label
     *
     * @return float
     */
    protected function predictProbability(array $sample, $label) : float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            return $this->prob[$label];
        }

        return 0.0;
    }

    /**
     * @param array $sample
     *
     * @return mixed $binaryLabels[0] when the split condition holds, otherwise $binaryLabels[1]
     */
    protected function predictSampleBinary(array $sample)
    {
        if ($this->evaluate($sample[$this->column], $this->operator, $this->value)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * The stump keeps no per-binary-problem state that needs resetting
     *
     * @return void
     */
    protected function resetBinary()
    {
    }

    /**
     * Human-readable form of the learned rule
     *
     * @return string
     */
    public function __toString()
    {
        return "IF $this->column $this->operator $this->value ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }
}
356