Passed
Push — master ( 47cdff...ed5fc8 )
by Arkadiusz
03:38
created

src/Phpml/Classification/Linear/DecisionStump.php (16 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\OneVsRest;
9
use Phpml\Classification\WeightedClassifier;
10
use Phpml\Classification\DecisionTree;
11
12
/**
 * One-level decision tree (a single decision node with two leaves), trained
 * per binary sub-problem through the OneVsRest trait and weighted through
 * WeightedClassifier's sample weights.
 */
class DecisionStump extends WeightedClassifier
{
    use Predictable, OneVsRest;

    const AUTO_SELECT = -1;

    /**
     * Column index requested by the caller, or self::AUTO_SELECT to let the
     * stump choose the column itself
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * Labels of the current binary sub-problem: index 0 is predicted when the
     * split condition holds, index 1 otherwise
     *
     * @var array
     */
    protected $binaryLabels;

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column tested by the decision node
     *
     * @var int
     */
    protected $column;

    /**
     * Cut point (continuous column) or category (nominal column) the sample
     * value is compared against
     *
     * @var mixed
     */
    protected $value;

    /**
     * Comparison operator of the decision node ('<=' / '>' for continuous
     * columns, '=' / '!=' for nominal columns)
     *
     * @var string
     */
    protected $operator;

    /**
     * Type of each column as reported by DecisionTree::getColumnTypes()
     *
     * @var array
     */
    protected $columnTypes;

    /**
     * Number of columns seen during training
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of equally sized intervals a continuous column's value range is
     * divided into when probing split points
     *
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves
     *
     * @var array
     */
    protected $prob;

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms as in the weak classifier role. <br>
     *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     *
     * @param int $columnIndex
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Trains the stump on one binary sub-problem by searching the candidate
     * columns for the split with the lowest weighted error rate.
     *
     * @param array $samples
     * @param array $targets
     * @param array $labels  the two labels of this binary sub-problem
     *
     * @throws \Exception when the number of sample weights does not match the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels)
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // A given column index must point at an existing column; anything out of
        // range (including negative values other than AUTO_SELECT) falls back to auto-selection
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex >= $this->featureCount) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample
        if (!empty($this->weights)) {
            if (count($this->weights) !== count($samples)) {
                throw new \Exception("Number of sample weights does not match with number of samples");
            }
        } else {
            $this->weights = array_fill(0, count($samples), 1);
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the candidate columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = $this->givenColumnIndex === self::AUTO_SELECT
            ? range(0, $this->featureCount - 1)
            : [$this->givenColumnIndex];

        $bestSplit = [
            'value' => 0,
            'operator' => '',
            'prob' => [],
            'column' => 0,
            'trainingErrorRate' => 1.0,
        ];
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
                $split = $this->getBestNumericalSplit($samples, $targets, $col);
            } else {
                $split = $this->getBestNominalSplit($samples, $targets, $col);
            }

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        $this->value = $bestSplit['value'];
        $this->operator = $bestSplit['operator'];
        $this->prob = $bestSplit['prob'];
        $this->column = $bestSplit['column'];
        $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
    }

    /**
     * While finding best split point for a numerical valued column,
     * DecisionStump looks for equally distanced values between minimum and maximum
     * values in the column. Given <i>$count</i> value determines how many split
     * points to be probed. The more split counts, the better performance but
     * worse processing time (Default value is 100.0)
     *
     * @param float $count
     */
    public function setNumericalSplitCount(float $count)
    {
        $this->numSplitCount = $count;
    }

    /**
     * Determines best split point for the given continuous column
     *
     * @param array $samples
     * @param array $targets
     * @param int $col
     *
     * @return array split record with keys value, operator, prob, column, trainingErrorRate
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col)
    {
        $values = array_column($samples, $col);
        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);
        $stepSize = ($maxValue - $minValue) / $this->numSplitCount;

        // Before trying the evenly spaced points, the average value is tried as the cut point
        $average = array_sum($values) / (float) count($values);

        $split = null;
        foreach (['<=', '>'] as $operator) {
            $split = $this->selectBetterSplit($split, $targets, $average, $operator, $values, $col);

            // A constant column (or a non-positive split count) gives a step size that
            // never advances past $maxValue; probing would loop forever without
            // finding a different cut point, so it is skipped
            if (!($stepSize > 0)) {
                continue;
            }

            // Try other possible points one by one
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $split = $this->selectBetterSplit($split, $targets, (float) $step, $operator, $values, $col);
            }
        }

        return $split;
    }

    /**
     * Determines best split point for the given nominal column by trying every
     * distinct category with both the '=' and '!=' operators
     *
     * @param array $samples
     * @param array $targets
     * @param int $col
     *
     * @return array split record with keys value, operator, prob, column, trainingErrorRate
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col) : array
    {
        $values = array_column($samples, $col);
        // array_unique() keeps the categories with their original types, whereas the
        // keys of array_count_values() turn numeric strings into integers that would
        // never match the strict '=' / '!=' comparison in evaluate()
        $distinctValues = array_values(array_unique($values));

        $split = null;
        foreach (['=', '!='] as $operator) {
            foreach ($distinctValues as $category) {
                $split = $this->selectBetterSplit($split, $targets, $category, $operator, $values, $col);
            }
        }

        return $split;
    }

    /**
     * Scores a candidate split and returns it when it has a strictly lower
     * error rate than the current best (or when there is no current best yet);
     * otherwise returns the current best unchanged.
     *
     * @param array|null $best       current best split record, null if none yet
     * @param array      $targets
     * @param mixed      $threshold  cut point or category of the candidate
     * @param string     $operator
     * @param array      $values     the column's values, index-aligned with $targets
     * @param int        $col
     *
     * @return array
     */
    private function selectBetterSplit($best, array $targets, $threshold, string $operator, array $values, int $col) : array
    {
        list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);

        if ($best === null || $errorRate < $best['trainingErrorRate']) {
            return [
                'value' => $threshold,
                'operator' => $operator,
                'prob' => $prob,
                'column' => $col,
                'trainingErrorRate' => $errorRate,
            ];
        }

        return $best;
    }

    /**
     * Applies a comparison operator to two values. '=' and '!=' / '<>' are
     * strict (type-sensitive) comparisons; unknown operators yield false.
     *
     * @param mixed  $leftValue
     * @param string $operator
     * @param mixed  $rightValue
     *
     * @return bool
     */
    protected function evaluate($leftValue, string $operator, $rightValue)
    {
        switch ($operator) {
            case '>':
                return $leftValue > $rightValue;
            case '>=':
                return $leftValue >= $rightValue;
            case '<':
                return $leftValue < $rightValue;
            case '<=':
                return $leftValue <= $rightValue;
            case '=':
                return $leftValue === $rightValue;
            case '!=':
            case '<>':
                return $leftValue !== $rightValue;
        }

        return false;
    }

    /**
     * Calculates the weighted ratio of wrong predictions obtained when the
     * column is split with the given threshold and operator, together with the
     * proportion of the predicted label within each leaf
     *
     * @param array  $targets
     * @param mixed  $threshold  numeric cut point (continuous column) or category (nominal column)
     * @param string $operator
     * @param array  $values     the column's values, index-aligned with $targets
     *
     * @return array [float error rate, array label => proportion within its leaf]
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values) : array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            $predicted = $this->evaluate($value, $operator, $threshold) ? $leftLabel : $rightLabel;
            $target = $targets[$index];

            // Labels are compared by their string form so that int keys and
            // string targets representing the same label match
            if (strval($predicted) !== strval($target)) {
                $wrong += $this->weights[$index];
            }

            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }
            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if (strval($leaf) === strval($label)) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        // NOTE(review): divides by zero if every sample weight is 0
        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node
     *
     * @param array $sample
     * @param mixed $label
     *
     * @return float
     */
    protected function predictProbability(array $sample, $label) : float
    {
        $predicted = $this->predictSampleBinary($sample);
        if (strval($predicted) === strval($label)) {
            return $this->prob[$label];
        }

        return 0.0;
    }

    /**
     * Predicts the first binary label when the sample satisfies the decision
     * node's condition, the second one otherwise
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        if ($this->evaluate($sample[$this->column], $this->operator, $this->value)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * Nothing to reset between binary sub-problems: trainBinary() overwrites all state
     *
     * @return void
     */
    protected function resetBinary()
    {
    }

    /**
     * Human-readable form of the decision rule
     *
     * @return string
     */
    public function __toString()
    {
        return "IF {$this->column} {$this->operator} {$this->value} "
            . "THEN {$this->binaryLabels[0]} "
            . "ELSE {$this->binaryLabels[1]}";
    }
}
365