1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification\Linear; |
6
|
|
|
|
7
|
|
|
use Phpml\Helper\Predictable; |
8
|
|
|
use Phpml\Helper\Trainable; |
9
|
|
|
use Phpml\Classification\Classifier; |
10
|
|
|
use Phpml\Classification\DecisionTree; |
11
|
|
|
use Phpml\Classification\DecisionTree\DecisionTreeLeaf; |
12
|
|
|
|
13
|
|
|
class DecisionStump extends DecisionTree
{
    use Trainable, Predictable;

    /**
     * Index of the column the decision node is built on. A value of -1 lets
     * the underlying tree choose the column; after training it holds the
     * column index of the trained root node.
     *
     * @var int
     */
    protected $columnIndex;

    /**
     * Sample weights : If used the optimization on the decision value
     * will take these weights into account. If not given, all samples
     * will be weighed with the same value of 1
     *
     * @var array|null
     */
    protected $weights = null;

    /**
     * Lowest error rate obtained while training/optimizing the model.
     * Only assigned by optimizeDecision(), i.e. when the selected column
     * is continuous.
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms as in the weak classifier role. <br>
     *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     *
     * @param int $columnIndex
     */
    public function __construct(int $columnIndex = -1)
    {
        $this->columnIndex = $columnIndex;

        // The parent is always constructed with 1 (a one-level deep tree).
        parent::__construct(1);
    }

    /**
     * Trains the stump and, for a continuous column, refines its threshold.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when sample weights were set and their number differs
     *                    from the number of samples
     */
    public function train(array $samples, array $targets)
    {
        // A column index beyond the last column of the first sample falls
        // back to automatic column selection.
        if ($this->columnIndex > count($samples[0]) - 1) {
            $this->columnIndex = -1;
        }

        if ($this->columnIndex >= 0) {
            $this->setSelectedFeatures([$this->columnIndex]);
        }

        // Explicit emptiness test instead of an implicit array-to-bool cast:
        // null and [] both mean "no weights given".
        if (!empty($this->weights)) {
            if (count($this->weights) !== count($samples)) {
                throw new \Exception("Number of sample weights does not match with number of samples");
            }
        } else {
            // Default: every sample weighs 1. The weights are kept on the instance.
            $this->weights = array_fill(0, count($samples), 1);
        }

        parent::train($samples, $targets);

        // Adopt whichever column the trained tree actually split on.
        $this->columnIndex = $this->tree->columnIndex;

        // For numerical values, try to optimize the value by finding a different threshold value
        if ($this->columnTypes[$this->columnIndex] == self::CONTINUOS) {
            $this->optimizeDecision($samples, $targets);
        }
    }

    /**
     * Used to set sample weights.
     *
     * @param array $weights
     */
    public function setSampleWeights(array $weights)
    {
        $this->weights = $weights;
    }

    /**
     * Returns the training error rate, the proportion of wrong predictions
     * over the total number of samples.
     *
     * The value is only assigned by optimizeDecision(); if that step did not
     * run (non-continuous column, or train() not called) the property still
     * holds its initial value.
     *
     * @return float
     */
    public function getTrainingErrorRate()
    {
        return $this->trainingErrorRate;
    }

    /**
     * Tries to optimize the threshold by probing a range of different values
     * between the minimum and maximum values in the selected column.
     *
     * Both '<=' and '>' are tried for every candidate threshold; the tree's
     * current operator/threshold pair is the baseline and is only replaced by
     * a strictly lower error rate. The root node and the training error rate
     * are updated with the best combination found.
     *
     * @param array $samples
     * @param array $targets
     */
    protected function optimizeDecision(array $samples, array $targets)
    {
        $values = array_column($samples, $this->columnIndex);
        $minValue = min($values);
        $maxValue = max($values);
        $stepSize = ($maxValue - $minValue) / 100.0;

        // Build the candidate thresholds once; they are the same for every
        // operator. When all values are equal the step size is 0 and stepping
        // would never pass $maxValue, so only that single value is probed.
        $thresholds = [];
        if ($stepSize > 0) {
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $thresholds[] = (float) $step;
            }
        } else {
            $thresholds[] = (float) $minValue;
        }

        $leftLabel = $this->tree->leftLeaf->classValue;
        $rightLabel = $this->tree->rightLeaf->classValue;

        $bestOperator = $this->tree->operator;
        $bestThreshold = $this->tree->numericValue;
        $bestErrorRate = $this->calculateErrorRate(
            $bestThreshold,
            $bestOperator,
            $values,
            $targets,
            $leftLabel,
            $rightLabel
        );

        foreach (['<=', '>'] as $operator) {
            foreach ($thresholds as $threshold) {
                $errorRate = $this->calculateErrorRate(
                    $threshold,
                    $operator,
                    $values,
                    $targets,
                    $leftLabel,
                    $rightLabel
                );

                if ($errorRate < $bestErrorRate) {
                    $bestErrorRate = $errorRate;
                    $bestThreshold = $threshold;
                    $bestOperator = $operator;
                }
            }
        }

        // Update the tree node value
        $this->tree->numericValue = $bestThreshold;
        $this->tree->operator = $bestOperator;
        $this->tree->value = "$bestOperator $bestThreshold";
        $this->trainingErrorRate = $bestErrorRate;
    }

    /**
     * Calculates the ratio of wrong predictions based on the new threshold
     * value given as the parameter.
     *
     * A sample is predicted as $leftLabel when "value operator threshold"
     * holds and as $rightLabel otherwise. The result is the summed weight of
     * the mispredicted samples divided by the sum of all sample weights.
     *
     * @param float  $threshold
     * @param string $operator
     * @param array  $values
     * @param array  $targets
     * @param mixed  $leftLabel
     * @param mixed  $rightLabel
     *
     * @return float
     *
     * @throws \InvalidArgumentException when the operator is not supported
     */
    protected function calculateErrorRate(
        float $threshold,
        string $operator,
        array $values,
        array $targets,
        $leftLabel,
        $rightLabel
    ) {
        $total = (float) array_sum($this->weights);
        $wrong = 0;

        foreach ($values as $index => $value) {
            $predicted = $this->matchesThreshold($value, $operator, $threshold)
                ? $leftLabel
                : $rightLabel;

            // Loose comparison is kept on purpose so that labels such as 1 and "1" still match.
            if ($predicted != $targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Evaluates "value operator threshold" without generating PHP code.
     *
     * '=' is treated as an equality test; evaluated as generated source it
     * would have been parsed as an assignment.
     *
     * @param mixed  $value
     * @param string $operator
     * @param float  $threshold
     *
     * @return bool
     *
     * @throws \InvalidArgumentException when the operator is not supported
     */
    private function matchesThreshold($value, string $operator, float $threshold): bool
    {
        switch ($operator) {
            case '<=':
                return $value <= $threshold;
            case '<':
                return $value < $threshold;
            case '>=':
                return $value >= $threshold;
            case '>':
                return $value > $threshold;
            case '=':
            case '==':
                return $value == $threshold;
            case '!=':
            case '<>':
                return $value != $threshold;
            default:
                throw new \InvalidArgumentException("Unsupported comparison operator: $operator");
        }
    }
}
180
|
|
|
|
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using `empty(...)` or `!empty(...)` instead.