These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more
1 | <?php |
||
2 | |||
3 | declare(strict_types=1); |
||
4 | |||
5 | namespace Phpml\Classification\Linear; |
||
6 | |||
7 | use Phpml\Helper\Predictable; |
||
8 | use Phpml\Helper\OneVsRest; |
||
9 | use Phpml\Classification\WeightedClassifier; |
||
10 | use Phpml\Classification\DecisionTree; |
||
11 | |||
12 | class DecisionStump extends WeightedClassifier |
||
13 | { |
||
14 | use Predictable, OneVsRest; |
||
15 | |||
/**
 * Sentinel column index: the stump chooses the decision column itself.
 */
const AUTO_SELECT = -1;

/**
 * Column index requested through the constructor, or AUTO_SELECT.
 *
 * @var int
 */
protected $givenColumnIndex;

/**
 * The two labels of the binary problem. Index 0 is predicted when the
 * stump's condition holds, index 1 otherwise.
 *
 * @var array
 */
protected $binaryLabels;

/**
 * Lowest error rate obtained while training/optimizing the model
 *
 * @var float
 */
protected $trainingErrorRate;

/**
 * Index of the column the selected split is made on.
 *
 * @var int
 */
protected $column;

/**
 * Value the selected column is compared against.
 *
 * @var mixed
 */
protected $value;

/**
 * Comparison operator of the selected split (see evaluate() for the
 * supported operators).
 *
 * @var string
 */
protected $operator;

/**
 * Column types as returned by DecisionTree::getColumnTypes().
 *
 * @var array
 */
protected $columnTypes;

/**
 * Number of features, taken from the first training sample.
 *
 * @var int
 */
protected $featureCount;

/**
 * Divisor applied to the (max - min) range of a continuous column to
 * obtain the step between probed thresholds.
 *
 * @var float
 */
protected $numSplitCount = 100.0;

/**
 * Distribution of samples in the leaves, keyed by label
 *
 * @var array
 */
protected $prob;
||
71 | |||
/**
 * Builds a DecisionStump: a DecisionTree that is only one level deep,
 * generally used in the weak classifier role of ensemble algorithms.
 *
 * When $columnIndex is given, the stump tries to build its decision node
 * on that column. With the default value (self::AUTO_SELECT, i.e. -1) the
 * stump decides by itself which column to use.
 *
 * @param int $columnIndex
 */
public function __construct(int $columnIndex = self::AUTO_SELECT)
{
    $this->givenColumnIndex = $columnIndex;
}
||
86 | |||
/**
 * Trains the stump on a two-label problem by probing candidate splits and
 * keeping the one with the lowest weighted error rate.
 *
 * @param array $samples Training samples; $samples[0] determines the feature count
 * @param array $targets Target label of each sample
 * @param array $labels  The two labels of the binary problem
 *
 * @throws \Exception When no usable sample is given, or when the number of
 *                    sample weights differs from the number of samples
 */
protected function trainBinary(array $samples, array $targets, array $labels)
{
    // Without at least one sample holding at least one feature there is
    // nothing to split on; fail early with a clear message
    if (empty($samples[0])) {
        throw new \Exception("Training samples must contain at least one sample with at least one feature");
    }

    $this->binaryLabels = $labels;
    $this->featureCount = count($samples[0]);
    $lastColumn = $this->featureCount - 1;

    // If a column index is given, it should be among the existing columns.
    // Anything outside [0, $lastColumn] (negative values included) falls
    // back to automatic column selection
    if ($this->givenColumnIndex < 0 || $this->givenColumnIndex > $lastColumn) {
        $this->givenColumnIndex = self::AUTO_SELECT;
    }

    // Check the size of the weights given.
    // If none given, then assign 1 as a weight to each sample
    $sampleCount = count($samples);
    if (!empty($this->weights)) {
        if (count($this->weights) !== $sampleCount) {
            throw new \Exception("Number of sample weights does not match with number of samples");
        }
    } else {
        $this->weights = array_fill(0, $sampleCount, 1);
    }

    // Determine type of each column as either "continuous" or "nominal"
    $this->columnTypes = DecisionTree::getColumnTypes($samples);

    // Try to find the best split in the columns of the dataset
    // by calculating error rate for each split point in each column
    $columns = $this->givenColumnIndex === self::AUTO_SELECT
        ? range(0, $lastColumn)
        : [$this->givenColumnIndex];

    // Baseline: a candidate is accepted only if its error rate is below 1.0
    $bestSplit = [
        'value' => 0, 'operator' => '',
        'prob' => [], 'column' => 0,
        'trainingErrorRate' => 1.0];
    foreach ($columns as $col) {
        if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
            $split = $this->getBestNumericalSplit($samples, $targets, $col);
        } else {
            $split = $this->getBestNominalSplit($samples, $targets, $col);
        }

        if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
            $bestSplit = $split;
        }
    }

    // Assign determined best values to the stump: every key of the split
    // array is copied to the property of the same name
    foreach ($bestSplit as $name => $value) {
        $this->{$name} = $value;
    }
}
||
146 | |||
/**
 * While finding best split point for a numerical valued column,
 * DecisionStump looks for equally distanced values between minimum and maximum
 * values in the column. Given <i>$count</i> value determines how many split
 * points to be probed. The more split counts, the better performance but
 * worse processing time (Default value is 100.0)
 *
 * @param float $count Must be a finite number greater than zero
 *
 * @throws \InvalidArgumentException When $count is not finite or not positive
 */
public function setNumericalSplitCount(float $count)
{
    // The count is used as a divisor for the probing step size, so zero,
    // negative or non-finite values would yield an unusable step
    if (!is_finite($count) || $count <= 0) {
        throw new \InvalidArgumentException("Numerical split count must be a finite number greater than zero");
    }

    $this->numSplitCount = $count;
}
||
160 | |||
/**
 * Determines best split point for the given numerical column
 *
 * Candidate thresholds are the mean of the column plus equally distanced
 * points between the column's minimum and maximum; each one is tried with
 * both the '<=' and '>' operators.
 *
 * @param array $samples
 * @param array $targets
 * @param int $col
 *
 * @return array Split with the keys 'value', 'operator', 'prob', 'column'
 *               and 'trainingErrorRate'
 */
protected function getBestNumericalSplit(array $samples, array $targets, int $col)
{
    $values = array_column($samples, $col);
    // Trying all possible points may be accomplished in two general ways:
    // 1- Try all values in the $samples array ($values)
    // 2- Artificially split the range of values into several parts and try them
    // We choose the second one because it is faster in larger datasets
    $minValue = min($values);
    $maxValue = max($values);
    // Guard the divisor: a non-positive split count yields no usable step
    $stepSize = $this->numSplitCount > 0
        ? ($maxValue - $minValue) / $this->numSplitCount
        : 0.0;

    // The mean does not depend on the operator, so compute it only once
    $meanValue = array_sum($values) / (float) count($values);

    $split = null;

    foreach (['<=', '>'] as $operator) {
        // Before trying all possible split points, let's first try
        // the average value for the cut point
        list($errorRate, $prob) = $this->calculateErrorRate($targets, $meanValue, $operator, $values);
        if ($split === null || $errorRate < $split['trainingErrorRate']) {
            $split = ['value' => $meanValue, 'operator' => $operator,
                'prob' => $prob, 'column' => $col,
                'trainingErrorRate' => $errorRate];
        }

        // A step that is not strictly positive (e.g. every value in the
        // column is identical) would never advance the loop below, so the
        // mean is the only candidate in that case
        if (!($stepSize > 0)) {
            continue;
        }

        // Try other possible points one by one
        for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
            $threshold = (float) $step;
            list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);
            if ($errorRate < $split['trainingErrorRate']) {
                $split = ['value' => $threshold, 'operator' => $operator,
                    'prob' => $prob, 'column' => $col,
                    'trainingErrorRate' => $errorRate];
            }
        }
    }

    return $split;
}
||
208 | |||
/**
 * Determines best split for the given nominal column
 *
 * Every distinct value of the column is tried with both the '=' and '!='
 * operators and the candidate with the lowest error rate is kept.
 *
 * @param array $samples
 * @param array $targets
 * @param int $col
 *
 * @return array Split with the keys 'value', 'operator', 'prob', 'column'
 *               and 'trainingErrorRate'
 */
protected function getBestNominalSplit(array $samples, array $targets, int $col) : array
{
    $values = array_column($samples, $col);
    $valueCounts = array_count_values($values);
    $distinctVals = array_keys($valueCounts);

    $split = null;

    foreach (['=', '!='] as $operator) {
        foreach ($distinctVals as $val) {
            list($errorRate, $prob) = $this->calculateErrorRate($targets, $val, $operator, $values);

            // Keep the candidate only when it lowers the error rate, the
            // same criterion the numerical split search uses
            if ($split === null || $errorRate < $split['trainingErrorRate']) {
                $split = ['value' => $val, 'operator' => $operator,
                    'prob' => $prob, 'column' => $col,
                    'trainingErrorRate' => $errorRate];
            }
        }
    }

    return $split;
}
||
238 | |||
239 | |||
/**
 * Applies the comparison named by $operator to the two operands.
 *
 * Supported operators are '>', '>=', '<', '<=', '=' (strict identity) and
 * '!=' / '<>' (strict non-identity). Any other operator yields false.
 *
 * @param mixed $leftValue
 * @param string $operator
 * @param mixed $rightValue
 *
 * @return boolean
 */
protected function evaluate($leftValue, string $operator, $rightValue)
{
    switch ($operator) {
        case '>':
            $outcome = $leftValue > $rightValue;
            break;
        case '>=':
            $outcome = $leftValue >= $rightValue;
            break;
        case '<':
            $outcome = $leftValue < $rightValue;
            break;
        case '<=':
            $outcome = $leftValue <= $rightValue;
            break;
        case '=':
            $outcome = $leftValue === $rightValue;
            break;
        case '!=':
        case '<>':
            $outcome = $leftValue !== $rightValue;
            break;
        default:
            // Unknown operators never match
            $outcome = false;
    }

    return $outcome;
}
||
262 | |||
/**
 * Calculates the ratio of wrong predictions based on the new threshold
 * value given as the parameter
 *
 * A sample is predicted as the first binary label when
 * "value <operator> threshold" holds and as the second one otherwise.
 * The weight of every wrongly predicted sample is accumulated and divided
 * by the sum of all weights.
 *
 * @param array $targets
 * @param mixed $threshold Numerical cut point, or the nominal value to compare with.
 *                         Deliberately untyped: a float declaration would reject
 *                         string values under strict types and would turn integer
 *                         values into floats, which the identity checks done for
 *                         '=' and '!=' never match
 * @param string $operator
 * @param array $values
 *
 * @return array Two elements: the weighted error rate and, keyed by label,
 *               the proportion of samples in the leaf predicting that label
 *               whose target equals the label
 */
protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values) : array
{
    $wrong = 0.0;
    $prob = [];
    $leftLabel = $this->binaryLabels[0];
    $rightLabel = $this->binaryLabels[1];

    foreach ($values as $index => $value) {
        if ($this->evaluate($value, $operator, $threshold)) {
            $predicted = $leftLabel;
        } else {
            $predicted = $rightLabel;
        }

        $target = $targets[$index];
        if (strval($predicted) != strval($target)) {
            $wrong += $this->weights[$index];
        }

        // Count the samples per (predicted leaf, actual target) pair
        if (!isset($prob[$predicted][$target])) {
            $prob[$predicted][$target] = 0;
        }
        ++$prob[$predicted][$target];
    }

    // Calculate probabilities: Proportion of labels in each leaf
    $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
    foreach ($prob as $leaf => $counts) {
        $leafTotal = (float) array_sum($counts);
        foreach ($counts as $label => $count) {
            if (strval($leaf) == strval($label)) {
                $dist[$leaf] = $count / $leafTotal;
            }
        }
    }

    return [$wrong / (float) array_sum($this->weights), $dist];
}
||
312 | |||
/**
 * Returns the probability of the sample of belonging to the given label
 *
 * Probability of a sample is calculated as the proportion of the label
 * within the labels of the training samples in the decision node. When the
 * stump predicts a different label for the sample, or no proportion is
 * stored for the label, the probability is 0.0.
 *
 * @param array $sample
 * @param mixed $label
 *
 * @return float
 */
protected function predictProbability(array $sample, $label) : float
{
    $predicted = $this->predictSampleBinary($sample);
    if (strval($predicted) != strval($label)) {
        return 0.0;
    }

    // The stored distribution can lack the label (the default split used
    // when no candidate is accepted carries an empty distribution); a
    // missing entry must not leak null out of a float-returning method
    if (!isset($this->prob[$label])) {
        return 0.0;
    }

    return (float) $this->prob[$label];
}
||
333 | |||
/**
 * Predicts the label of a sample using the trained split: the first
 * binary label when the split condition holds for the sample's value in
 * the decision column, the second binary label otherwise.
 *
 * @param array $sample
 *
 * @return mixed
 */
protected function predictSampleBinary(array $sample)
{
    $conditionHolds = $this->evaluate($sample[$this->column], $this->operator, $this->value);

    return $conditionHolds ? $this->binaryLabels[0] : $this->binaryLabels[1];
}
||
347 | |||
/**
 * Intentionally a no-op: this method performs no work.
 *
 * NOTE(review): presumably a hook required by the OneVsRest trait; confirm
 * against the trait before relying on it.
 *
 * @return void
 */
protected function resetBinary()
{
}
||
354 | |||
/**
 * Renders the stump as a readable rule of the form
 * "IF <column> <operator> <value> THEN <first label> ELSE <second label>".
 *
 * @return string
 */
public function __toString()
{
    $condition = "IF {$this->column} {$this->operator} {$this->value} ";
    $whenTrue = "THEN " . $this->binaryLabels[0] . " ";
    $whenFalse = "ELSE " . $this->binaryLabels[1];

    return $condition . $whenTrue . $whenFalse;
}
||
364 | } |
||
365 |
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using `empty(...)` or `!empty(...)` instead.