1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification\Linear; |
6
|
|
|
|
7
|
|
|
use Phpml\Helper\Predictable; |
8
|
|
|
use Phpml\Helper\Trainable; |
9
|
|
|
use Phpml\Classification\WeightedClassifier; |
10
|
|
|
use Phpml\Classification\DecisionTree; |
11
|
|
|
|
12
|
|
|
class DecisionStump extends WeightedClassifier
{
    use Trainable, Predictable;

    const AUTO_SELECT = -1;

    /**
     * Column index requested by the caller, or AUTO_SELECT to let the stump
     * choose the decision column itself.
     *
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * Sample weights: if used, the optimization of the decision value
     * takes these weights into account. If not given, every sample
     * is weighted with the same value of 1.
     *
     * @var array|null
     */
    protected $weights = null;

    /**
     * Lowest (weighted) error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * Index of the column the decision is made on
     *
     * @var int
     */
    protected $column;

    /**
     * Threshold (numerical column) or reference value (nominal column)
     *
     * @var mixed
     */
    protected $value;

    /**
     * Comparison operator of the decision: '<=' or '>' for numerical
     * columns, '=' or '!=' for nominal columns
     *
     * @var string
     */
    protected $operator;

    /**
     * Column types as returned by DecisionTree::getColumnTypes()
     *
     * @var array
     */
    protected $columnTypes;

    /**
     * The two distinct target labels seen during training, in order of first
     * appearance. labels[0] is predicted when the decision holds, labels[1]
     * otherwise.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of equally spaced intervals probed for numerical columns
     *
     * @var float
     */
    protected $numSplitCount = 10.0;

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms in the weak classifier role. <br>
     *
     * If columnIndex is given, the stump tries to produce a decision node
     * on this column. Otherwise (a value of -1, AUTO_SELECT), the stump itself
     * decides which column to take for the decision (default DecisionTree behaviour).
     *
     * @param int $columnIndex
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Appends the given samples/targets to the training set and searches
     * for the single split with the lowest weighted error rate.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when the accumulated targets do not contain exactly
     *                    two distinct labels, or when previously set sample
     *                    weights do not match the number of samples
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // DecisionStump is capable of classifying between two classes only
        $this->labels = array_keys(array_count_values($this->targets));
        if (count($this->labels) != 2) {
            throw new \Exception("DecisionStump can classify between two classes only:" . implode(',', $this->labels));
        }

        // Use the accumulated samples so the column range agrees with the
        // column types computed from $this->samples below.
        $columnCount = count($this->samples[0]);

        // A given column index must refer to an existing column; otherwise
        // fall back to letting the stump choose the column itself.
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex > $columnCount - 1) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample
        if ($this->weights) {
            if (count($this->weights) != count($this->samples)) {
                throw new \Exception("Number of sample weights does not match with number of samples");
            }
        } else {
            // One weight per accumulated sample (not only the newly given
            // ones) so calculateErrorRate() finds a weight for every index.
            $this->weights = array_fill(0, count($this->samples), 1);
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($this->samples);

        // Try to find the best split in the columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = range(0, $columnCount - 1);
        if ($this->givenColumnIndex != self::AUTO_SELECT) {
            $columns = [$this->givenColumnIndex];
        }

        $bestSplit = [
            'value' => 0, 'operator' => '',
            'column' => 0, 'trainingErrorRate' => 1.0,
        ];
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) {
                $split = $this->getBestNumericalSplit($col);
            } else {
                $split = $this->getBestNominalSplit($col);
            }

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        $this->value = $bestSplit['value'];
        $this->operator = $bestSplit['operator'];
        $this->column = $bestSplit['column'];
        $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
    }

    /**
     * While finding the best split point for a numerical valued column,
     * DecisionStump probes equally spaced values between the minimum and
     * maximum values in the column. The given <i>$count</i> determines how
     * many intervals are probed: more intervals may find a better split but
     * take longer to train (default value is 10.0). A value <= 0 disables
     * the sweep so that only the column average is tried.
     *
     * @param float $count
     */
    public function setNumericalSplitCount(float $count)
    {
        $this->numSplitCount = $count;
    }

    /**
     * Determines the best split point for the given numerical column
     *
     * @param int $col
     *
     * @return array split with keys 'value', 'operator', 'column', 'trainingErrorRate'
     */
    protected function getBestNumericalSplit(int $col)
    {
        $values = array_column($this->samples, $col);
        $minValue = min($values);
        $maxValue = max($values);

        // A non-positive split count would divide by zero or step backwards.
        $stepSize = $this->numSplitCount > 0
            ? ($maxValue - $minValue) / $this->numSplitCount
            : 0.0;

        // Before trying all possible split points, first try the average
        // value as the cut point (it does not depend on the operator).
        $average = array_sum($values) / (float) count($values);

        $split = null;
        foreach (['<=', '>'] as $operator) {
            $errorRate = $this->calculateErrorRate($average, $operator, $values);
            if ($split === null || $errorRate < $split['trainingErrorRate']) {
                $split = $this->createSplit($average, $operator, $col, $errorRate);
            }

            // A zero step (constant column, or sweep disabled) would never
            // advance the loop below; the average is then the only candidate.
            if ($stepSize <= 0) {
                continue;
            }

            // Try other possible points one by one
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $threshold = (float) $step;
                $errorRate = $this->calculateErrorRate($threshold, $operator, $values);
                if ($errorRate < $split['trainingErrorRate']) {
                    $split = $this->createSplit($threshold, $operator, $col, $errorRate);
                }
            }
        }

        return $split;
    }

    /**
     * Determines the best split point for the given nominal column by trying
     * each distinct value of the column with both '=' and '!='.
     *
     * @param int $col
     *
     * @return array split with keys 'value', 'operator', 'column', 'trainingErrorRate'
     */
    protected function getBestNominalSplit(int $col)
    {
        $values = array_column($this->samples, $col);
        $distinctVals = array_keys(array_count_values($values));

        $split = null;
        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                $errorRate = $this->calculateErrorRate($val, $operator, $values);
                // Keep the candidate with the LOWEST error rate
                if ($split === null || $errorRate < $split['trainingErrorRate']) {
                    $split = $this->createSplit($val, $operator, $col, $errorRate);
                }
            }
        }

        return $split;
    }

    /**
     * Builds the split description array shared by the split finders.
     *
     * @param mixed  $value
     * @param string $operator
     * @param int    $col
     * @param float  $errorRate
     *
     * @return array
     */
    private function createSplit($value, string $operator, int $col, float $errorRate)
    {
        return [
            'value' => $value, 'operator' => $operator,
            'column' => $col, 'trainingErrorRate' => $errorRate,
        ];
    }

    /**
     * Evaluates "$leftValue $operator $rightValue" using PHP's loose comparison.
     *
     * @param mixed  $leftValue
     * @param string $operator one of '>', '>=', '<', '<=', '=', '!=', '<>'
     * @param mixed  $rightValue
     *
     * @return bool false for an unknown operator
     */
    protected function evaluate($leftValue, $operator, $rightValue)
    {
        switch ($operator) {
            case '>': return $leftValue > $rightValue;
            case '>=': return $leftValue >= $rightValue;
            case '<': return $leftValue < $rightValue;
            case '<=': return $leftValue <= $rightValue;
            case '=': return $leftValue == $rightValue;
            case '!=':
            case '<>': return $leftValue != $rightValue;
        }

        return false;
    }

    /**
     * Calculates the weighted ratio of wrong predictions obtained when
     * "$threshold $operator value" selects labels[0] and its negation
     * selects labels[1].
     *
     * The threshold is left untyped because nominal columns pass their
     * (possibly string) values here; a float type would throw a TypeError
     * under strict_types.
     *
     * @param mixed  $threshold
     * @param string $operator
     * @param array  $values column values, indexed like $this->targets
     *
     * @return float
     */
    protected function calculateErrorRate($threshold, string $operator, array $values)
    {
        $total = (float) array_sum($this->weights);
        $wrong = 0.0;
        $leftLabel = $this->labels[0];
        $rightLabel = $this->labels[1];
        foreach ($values as $index => $value) {
            if ($this->evaluate($threshold, $operator, $value)) {
                $predicted = $leftLabel;
            } else {
                $predicted = $rightLabel;
            }

            if ($predicted != $this->targets[$index]) {
                $wrong += $this->weights[$index];
            }
        }

        return $wrong / $total;
    }

    /**
     * Predicts labels[0] when "value operator sample[column]" holds,
     * labels[1] otherwise.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if ($this->evaluate($this->value, $this->operator, $sample[$this->column])) {
            return $this->labels[0];
        }

        return $this->labels[1];
    }

    /**
     * @return string "column operator value"
     */
    public function __toString()
    {
        return "$this->column $this->operator $this->value";
    }
}
295
|
|
|
|
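In PHP it is possible to write to properties without declaring them. For example, `$this->labels = array_keys($labels);` is perfectly valid PHP code even though `$labels` is not declared as a property of the class.
Generally, it is good practice to explicitly declare properties to avoid accidental typos and to provide IDE auto-completion, e.g. `protected $labels = [];`.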