Completed
Push to master (cf222b...4daa0a)
by Arkadiusz
03:24
created

DecisionStump::optimizeDecision() (rating: B)

Complexity

Conditions 4
Paths 4

Size

Total Lines 35
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 35
rs 8.5806
c 0
b 0
f 0
cc 4
eloc 24
nc 4
nop 2
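The Complexity figures for optimizeDecision() (cc 4 and nc 4, shown above as Conditions and Paths) can be reproduced by hand. This assumes Scrutinizer follows the usual PHP_Depend/PHPMD counting rules: cyclomatic complexity is one plus the number of decision points, a loop or an `if` without `else` contributes its body's path count plus one, and sequential statements multiply. The method body contains one `foreach`, one nested `for`, and one nested `if`:

$$
\begin{aligned}
\mathrm{CC} &= 1 + \underbrace{1}_{\texttt{foreach}} + \underbrace{1}_{\texttt{for}} + \underbrace{1}_{\texttt{if}} = 4 \\
\mathrm{NP}(\texttt{if}) &= 1_{\text{then}} + 1_{\text{skip}} = 2 \\
\mathrm{NP}(\texttt{for}) &= \mathrm{NP}(\texttt{if}) + 1 = 3 \\
\mathrm{NP}(\texttt{foreach}) &= \mathrm{NP}(\texttt{for}) + 1 = 4 \\
\mathrm{NPath} &= 1 \cdot \mathrm{NP}(\texttt{foreach}) \cdot 1 = 4
\end{aligned}
$$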
<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Classification\Classifier;
use Phpml\Classification\DecisionTree;
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;

class DecisionStump extends DecisionTree
{
    use Trainable, Predictable;

    /**
     * @var int
     */
    protected $columnIndex;

    /**
     * Sample weights: if given, the optimization of the decision value
     * takes these weights into account. If not given, all samples
     * are weighted with the same value of 1.
     *
     * @var array
     */
    protected $weights = null;

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * A DecisionStump classifier is a one-level-deep DecisionTree. It is generally
     * used in ensemble algorithms in the role of the weak classifier. <br>
     *
     * If columnIndex is given, the stump tries to produce a decision node on
     * that column. With the default value of -1, the stump itself decides which
     * column to use for the decision (default DecisionTree behaviour).
     *
     * @param int $columnIndex
     */
    public function __construct(int $columnIndex = -1)
    {
        $this->columnIndex = $columnIndex;

        parent::__construct(1);
    }

    /**
     * @param array $samples
     * @param array $targets
     */
    public function train(array $samples, array $targets)
    {
        if ($this->columnIndex > count($samples[0]) - 1) {
            $this->columnIndex = -1;
        }

        if ($this->columnIndex >= 0) {
            $this->setSelectedFeatures([$this->columnIndex]);
        }

        if ($this->weights) {
Bug Best Practice introduced by
The expression $this->weights of type array is implicitly converted to a boolean; are you sure this is intended? If so, consider using ! empty($expr) instead to make it clear that you intend to check for an array without elements.

This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.

Consider making the comparison explicit by using empty(...) or ! empty(...) instead.

            $numWeights = count($this->weights);
            if ($numWeights != count($samples)) {
                throw new \Exception("Number of sample weights does not match with number of samples");
            }
        } else {
            $this->weights = array_fill(0, count($samples), 1);
        }

        parent::train($samples, $targets);

        $this->columnIndex = $this->tree->columnIndex;

        // For numerical values, try to optimize the value by finding a different threshold value
        if ($this->columnTypes[$this->columnIndex] == self::CONTINUOS) {
            $this->optimizeDecision($samples, $targets);
        }
    }

    /**
     * Used to set sample weights.
     *
     * @param array $weights
     */
    public function setSampleWeights(array $weights)
    {
        $this->weights = $weights;
    }

    /**
     * Returns the training error rate: the weighted proportion of wrong
     * predictions over the total number of samples. It is only computed
     * for continuous columns (in optimizeDecision()); otherwise train()
     * does not set it.
     *
     * @return float
     */
    public function getTrainingErrorRate()
    {
        return $this->trainingErrorRate;
    }

    /**
     * Tries to optimize the threshold by probing evenly spaced values, in steps
     * of (max - min) / 100, between the minimum and maximum values of the
     * selected column, for both the '<=' and '>' operators
     *
     * @param array $samples
     * @param array $targets
     */
    protected function optimizeDecision(array $samples, array $targets)
    {
        $values = array_column($samples, $this->columnIndex);
        $minValue = min($values);
        $maxValue = max($values);
        $stepSize = ($maxValue - $minValue) / 100.0;

        $leftLabel = $this->tree->leftLeaf->classValue;
        $rightLabel = $this->tree->rightLeaf->classValue;

        $bestOperator = $this->tree->operator;
        $bestThreshold = $this->tree->numericValue;
        $bestErrorRate = $this->calculateErrorRate(
            $bestThreshold, $bestOperator, $values, $targets, $leftLabel, $rightLabel);

        foreach (['<=', '>'] as $operator) {
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $threshold = (float) $step;
                $errorRate = $this->calculateErrorRate(
                    $threshold, $operator, $values, $targets, $leftLabel, $rightLabel);

                if ($errorRate < $bestErrorRate) {
                    $bestErrorRate = $errorRate;
                    $bestThreshold = $threshold;
                    $bestOperator = $operator;
                }
            }
        }

        // Update the tree node value
        $this->tree->numericValue = $bestThreshold;
        $this->tree->operator = $bestOperator;
        $this->tree->value = "$bestOperator $bestThreshold";
        $this->trainingErrorRate = $bestErrorRate;
    }

    /**
     * Calculates the weighted ratio of wrong predictions obtained when the
     * given threshold and operator are used for the decision
     *
     * @param float $threshold
     * @param string $operator
     * @param array $values
     * @param array $targets
     * @param mixed $leftLabel
     * @param mixed $rightLabel
     *
     * @return float
     */
    protected function calculateErrorRate(float $threshold, string $operator, array $values, array $targets, $leftLabel, $rightLabel)
Unused Code introduced by
The parameter $threshold is not used and could be removed.

This check looks for parameters that have been defined for a function or method, but which are not used in the method body.

Unused Code introduced by
The parameter $leftLabel is not used and could be removed.

This check looks for parameters that have been defined for a function or method, but which are not used in the method body.

Unused Code introduced by
The parameter $rightLabel is not used and could be removed.

This check looks for parameters that have been defined for a function or method, but which are not used in the method body.

    {
        $total = (float) array_sum($this->weights);
        $wrong = 0;

        foreach ($values as $index => $value) {
            eval("\$predicted = \$value $operator \$threshold ? \$leftLabel : \$rightLabel;");
Coding Style introduced by
It is generally not recommended to use eval unless absolutely required.

On one hand, eval might be exploited by malicious users if they somehow manage to inject dynamic content. On the other hand, with the emergence of faster PHP runtimes like HHVM, eval prevents some optimizations that they perform.

            if ($predicted != $targets[$index]) {
Bug introduced by
The variable $predicted does not exist. Did you forget to declare it?

This check marks access to variables or properties that have not been declared yet. While PHP has no explicit notion of declaring a variable, accessing it before a value is assigned to it is most likely a bug.

                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }
}
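The three Unused Code notices on calculateErrorRate() and the undefined-variable notice on $predicted appear to share one cause: $threshold, $leftLabel and $rightLabel are referenced only inside the string passed to eval(), and $predicted is assigned only there, so the analyser cannot see those uses. Because the optimizeDecision() loop only probes '<=' and '>', a plain switch can replace eval(). The sketch below is a standalone function, not a patch to the class: the name weightedErrorRate() and the explicit $weights parameter are hypothetical, it assumes sample indices run 0..n-1 (as array_column() produces), and it rejects any other operator, so the switch would need extending if the tree's initial operator can be something else. It also uses an explicit empty() check, as the Best Practice notice on train() suggests.

```php
<?php

declare(strict_types=1);

/**
 * Hypothetical eval-free variant of DecisionStump::calculateErrorRate().
 * Weights are passed in explicitly instead of being read from $this->weights.
 */
function weightedErrorRate(
    float $threshold,
    string $operator,
    array $values,
    array $targets,
    $leftLabel,
    $rightLabel,
    array $weights = []
): float {
    // Explicit empty() check instead of an implicit array-to-bool conversion;
    // missing weights default to 1 per sample, as train() does.
    if (empty($weights)) {
        $weights = array_fill(0, count($values), 1);
    } elseif (count($weights) !== count($values)) {
        throw new InvalidArgumentException('Number of sample weights does not match the number of samples');
    }

    $total = (float) array_sum($weights);
    if ($total <= 0.0) {
        throw new InvalidArgumentException('Sample weights must sum to a positive value');
    }

    $wrong = 0.0;
    foreach ($values as $index => $value) {
        // Dispatch on the operator instead of building PHP source for eval().
        switch ($operator) {
            case '<=':
                $goesLeft = $value <= $threshold;
                break;
            case '>':
                $goesLeft = $value > $threshold;
                break;
            default:
                throw new InvalidArgumentException("Unsupported operator: $operator");
        }
        $predicted = $goesLeft ? $leftLabel : $rightLabel;

        // Loose comparison, kept from the original.
        if ($predicted != $targets[$index]) {
            $wrong += $weights[$index];
        }
    }

    return $wrong / $total;
}
```

With this shape every parameter is visibly used, and an operator outside the whitelist raises an exception instead of being spliced into executable code.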
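The probing loop in optimizeDecision() has two numeric weak spots that the inspection does not flag. If the selected column holds a single distinct value, $stepSize is 0, so `$step <= $maxValue` never becomes false and the loop does not terminate. And because $step is built by repeated addition, accumulated rounding error can carry the last candidate just past $maxValue, so the maximum itself may never be probed. A minimal sketch of one way around both, computing each candidate as min + i * step; candidateThresholds() is a hypothetical helper name, the default of 100 steps mirrors the original divisor, and a non-empty column is assumed:

```php
<?php

declare(strict_types=1);

/**
 * Hypothetical helper: the thresholds optimizeDecision() would probe,
 * computed by index instead of by repeatedly adding the step size.
 */
function candidateThresholds(array $values, int $steps = 100): array
{
    $minValue = (float) min($values);
    $maxValue = (float) max($values);

    // A constant column has a single meaningful threshold; a zero step
    // size would otherwise keep the probing loop from terminating.
    if ($minValue === $maxValue) {
        return [$minValue];
    }

    $stepSize = ($maxValue - $minValue) / $steps;
    $thresholds = [];
    for ($i = 0; $i <= $steps; ++$i) {
        // min + i * step avoids accumulated rounding error; the last
        // candidate is pinned to the exact maximum.
        $thresholds[] = $i === $steps ? $maxValue : $minValue + $i * $stepSize;
    }

    return $thresholds;
}
```

Inside the operator loop, `foreach (candidateThresholds($values) as $threshold)` could then stand in for the `for` loop, with the rest of the comparison logic unchanged.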