These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more
1 | <?php |
||
2 | |||
3 | declare(strict_types=1); |
||
4 | |||
5 | namespace Phpml\Classification; |
||
6 | |||
7 | use Phpml\Classification\DecisionTree\DecisionTreeLeaf; |
||
8 | use Phpml\Exception\InvalidArgumentException; |
||
9 | use Phpml\Helper\Predictable; |
||
10 | use Phpml\Helper\Trainable; |
||
11 | use Phpml\Math\Statistic\Mean; |
||
12 | |||
13 | class DecisionTree implements Classifier |
||
14 | { |
||
15 | use Trainable, Predictable; |
||
16 | |||
/**
 * Column type marker: preprocess() turns values of such a column into
 * "<= median" / "> median" labels before a split is evaluated.
 */
public const CONTINUOUS = 1;

/**
 * Column type marker: values of such a column are used as they are.
 */
public const NOMINAL = 2;

/**
 * Deepest level assigned to a node by getSplitLeaf().
 *
 * @var int
 */
public $actualDepth = 0;

/**
 * One of self::CONTINUOUS / self::NOMINAL per column, filled in by train().
 *
 * @var array
 */
protected $columnTypes = [];

/**
 * Root node of the tree; stays null until train() has been called.
 *
 * @var DecisionTreeLeaf|null
 */
protected $tree;

/**
 * Depth at which getSplitLeaf() stops splitting.
 *
 * @var int
 */
protected $maxDepth;

/**
 * Distinct target values seen in the training data.
 *
 * @var array
 */
private $labels = [];

/**
 * Number of columns in the first training sample.
 *
 * @var int
 */
private $featureCount = 0;

/**
 * Upper bound on randomly chosen columns per split; 0 means "use all".
 *
 * @var int
 */
private $numUsableFeatures = 0;

/**
 * Explicit list of column indices to consider for splits (empty = not set).
 *
 * @var array
 */
private $selectedFeatures = [];

/**
 * Cached result of getFeatureImportances(); null means "not computed yet".
 *
 * @var array|null
 */
private $featureImportances;

/**
 * Display names of the columns, indexed by column position.
 *
 * @var array
 */
private $columnNames = [];
||
70 | |||
/**
 * Creates an untrained tree.
 *
 * @param int $maxDepth depth at which getSplitLeaf() marks a node as terminal
 */
public function __construct(int $maxDepth = 10)
{
    // Only the depth limit is configured here; everything else is set by train().
    $this->maxDepth = $maxDepth;
}
||
75 | |||
/**
 * Appends the given samples/targets to the stored training data and
 * rebuilds the whole tree from the accumulated data.
 *
 * @param array $samples rows of feature values; the first row defines the feature count
 * @param array $targets class label of each sample
 *
 * @throws InvalidArgumentException when there is no sample at all to train on
 */
public function train(array $samples, array $targets): void
{
    $this->samples = array_merge($this->samples, $samples);
    $this->targets = array_merge($this->targets, $targets);

    // Without this guard the feature count lookup below would fail on a missing index 0
    if ($this->samples === []) {
        throw new InvalidArgumentException('Cannot train a decision tree without any samples');
    }

    $this->featureCount = count($this->samples[0]);
    $this->columnTypes = self::getColumnTypes($this->samples);
    $this->labels = array_keys(array_count_values($this->targets));

    // The tree is rebuilt from scratch, so the depth recorded by a previous
    // training run must not leak into the new one
    $this->actualDepth = 0;
    $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

    // Each time the tree is trained, feature importances are reset so that
    // they are computed again from the new tree on demand
    $this->featureImportances = null;

    // Column names given (or generated) earlier are kept; they are only
    // created, trimmed or padded with indices to match the feature count
    $nameCount = count($this->columnNames);
    if ($nameCount === 0) {
        $this->columnNames = range(0, $this->featureCount - 1);
    } elseif ($nameCount > $this->featureCount) {
        $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
    } elseif ($nameCount < $this->featureCount) {
        $this->columnNames = array_merge(
            $this->columnNames,
            range($nameCount, $this->featureCount - 1)
        );
    }
}
||
103 | |||
/**
 * Classifies every column of the sample matrix as self::NOMINAL or
 * self::CONTINUOUS, based on isCategoricalColumn().
 *
 * @param array $samples rows of feature values; the first row defines the column count
 *
 * @return array one type constant per column, in column order
 */
public static function getColumnTypes(array $samples): array
{
    $columnCount = count($samples[0]);
    $columnTypes = [];

    $column = 0;
    while ($column < $columnCount) {
        if (self::isCategoricalColumn(array_column($samples, $column))) {
            $columnTypes[] = self::NOMINAL;
        } else {
            $columnTypes[] = self::CONTINUOUS;
        }

        ++$column;
    }

    return $columnTypes;
}
||
116 | |||
/**
 * Computes the weighted Gini impurity of splitting a column into
 * "equals $baseValue" and "everything else".
 *
 * @param mixed $baseValue value that defines the first partition (strict comparison)
 * @param array $colValues column values; keys must match the keys of $targets
 * @param array $targets   class label for each entry of $colValues
 *
 * @return float impurity of both partitions weighted by their size;
 *               0.0 when $colValues is empty
 */
public function getGiniIndex($baseValue, array $colValues, array $targets): float
{
    $total = count($colValues);
    if ($total === 0) {
        // Nothing to split: avoid the division by zero at the end
        return 0.0;
    }

    // Per label: [count in the "equals base value" part, count in the other part]
    $countMatrix = [];
    foreach ($this->labels as $label) {
        $countMatrix[$label] = [0, 0];
    }

    foreach ($colValues as $index => $value) {
        $side = $value === $baseValue ? 0 : 1;
        ++$countMatrix[$targets[$index]][$side];
    }

    $weightedImpurity = 0.0;
    foreach ([0, 1] as $side) {
        $sideTotal = array_sum(array_column($countMatrix, $side));
        if ($sideTotal <= 0) {
            // An empty partition contributes nothing to the weighted sum
            continue;
        }

        $purity = 0.0;
        foreach ($this->labels as $label) {
            $purity += ($countMatrix[$label][$side] / (float) $sideTotal) ** 2;
        }

        $weightedImpurity += (1 - $purity) * $sideTotal;
    }

    return $weightedImpurity / $total;
}
||
148 | |||
/**
 * Sets how many columns may be considered when a split is decided at an
 * internal node of the tree.
 *
 * With 0 (the default) every feature is used; any other value is the
 * maximum number of columns randomly selected for each split operation.
 *
 * @return $this
 *
 * @throws InvalidArgumentException when the given number is negative
 */
public function setNumFeatures(int $numFeatures)
{
    $isNegative = $numFeatures < 0;
    if ($isNegative) {
        throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
    }

    $this->numUsableFeatures = $numFeatures;

    return $this;
}
||
170 | |||
/**
 * Assigns display names to the columns. Useful when the HTML output or
 * the feature importances are to be inspected.
 *
 * The length is only validated once the feature count is known
 * (i.e. it is not zero).
 *
 * @return $this
 *
 * @throws InvalidArgumentException when the number of names differs from the known feature count
 */
public function setColumnNames(array $names)
{
    $expected = $this->featureCount;
    $countIsKnown = $expected !== 0;

    if ($countIsKnown && count($names) !== $expected) {
        $message = sprintf('Length of the given array should be equal to feature count %s', $expected);

        throw new InvalidArgumentException($message);
    }

    $this->columnNames = $names;

    return $this;
}
||
189 | |||
/**
 * Returns the HTML produced by the root node's getHTML() for the
 * current column names.
 */
public function getHtml(): string
{
    $root = $this->tree;

    return $root->getHTML($this->columnNames);
}
||
194 | |||
/**
 * Returns an importance value for each column, keyed by column name.
 *
 * The importance of a column is the sum of getNodeImpurityDecrease() over
 * all internal nodes splitting on that column. When the total is positive
 * the values are normalized so that they sum to 1 and are sorted in
 * descending order. The result is cached until the next train() call.
 */
public function getFeatureImportances(): array
{
    if ($this->featureImportances === null) {
        $totalSamples = count($this->samples);
        $this->featureImportances = [];

        foreach ($this->columnNames as $column => $columnName) {
            $score = 0;
            foreach ($this->getSplitNodesByColumn($column, $this->tree) as $splitNode) {
                $score += $splitNode->getNodeImpurityDecrease($totalSamples);
            }

            $this->featureImportances[$columnName] = $score;
        }

        // Normalize & sort the importances
        $sum = array_sum($this->featureImportances);
        if ($sum > 0) {
            $this->featureImportances = array_map(
                static function ($score) use ($sum) {
                    return $score / $sum;
                },
                $this->featureImportances
            );

            arsort($this->featureImportances);
        }
    }

    return $this->featureImportances;
}
||
231 | |||
/**
 * Recursively builds the (sub)tree for the given records.
 *
 * The best split for the records is chosen first. The node becomes
 * terminal when all records are equal, when the maximum depth is reached
 * or when only one class remains; its class value is then the most
 * frequent remaining target. Otherwise the records are distributed to
 * the left/right child according to the split criterion.
 *
 * @param array $records indices into the stored samples/targets
 * @param int   $depth   level of the node to create (root = 0)
 */
protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
{
    $split = $this->getBestSplit($records);
    $split->level = $depth;
    $this->actualDepth = max($this->actualDepth, $depth);

    $leftRecords = [];
    $rightRecords = [];
    $remainingTargets = [];
    $prevRecord = null;
    $allSame = true;

    foreach ($records as $recordNo) {
        $record = $this->samples[$recordNo];

        // Compare against null explicitly: an empty-array record is falsy and
        // would otherwise be mistaken for "there is no previous record"
        if ($prevRecord !== null && $prevRecord != $record) {
            $allSame = false;
        }

        $prevRecord = $record;

        // According to the split criterion, this record belongs
        // to either the left or the right side of the next split
        if ($split->evaluate($record)) {
            $leftRecords[] = $recordNo;
        } else {
            $rightRecords[] = $recordNo;
        }

        // Count the targets so the leaf can be classified if it turns out terminal
        $target = $this->targets[$recordNo];
        $remainingTargets[$target] = ($remainingTargets[$target] ?? 0) + 1;
    }

    if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
        $split->isTerminal = 1;
        // Most frequent target first; key() then yields the majority class
        arsort($remainingTargets);
        $split->classValue = key($remainingTargets);

        return $split;
    }

    // Only create children that actually received records
    if ($leftRecords !== []) {
        $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
    }

    if ($rightRecords !== []) {
        $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
    }

    return $split;
}
||
291 | |||
/**
 * Finds the split with the lowest Gini index among the selected features.
 *
 * For every candidate column the most frequent (preprocessed) value is
 * used as the split value. The first candidate is always accepted; later
 * ones only replace it when their Gini index is strictly lower.
 *
 * @param array $records indices into the stored samples/targets
 */
protected function getBestSplit(array $records): DecisionTreeLeaf
{
    $recordKeys = array_flip($records);
    $targets = array_intersect_key($this->targets, $recordKeys);
    $rows = array_intersect_key($this->samples, $recordKeys);
    $rows = array_combine($records, $this->preprocess($rows));

    $lowestGini = 1;
    $best = null;

    foreach ($this->getSelectedFeatures() as $column) {
        // Values of the candidate column, keyed by record number
        $columnValues = [];
        foreach ($rows as $recordNo => $row) {
            $columnValues[$recordNo] = $row[$column];
        }

        // The most frequent value of the column is the split value
        $frequencies = array_count_values($columnValues);
        arsort($frequencies);
        $mostFrequent = key($frequencies);

        $gini = $this->getGiniIndex($mostFrequent, $columnValues, $targets);
        $improves = $best === null || $lowestGini > $gini;
        if (!$improves) {
            continue;
        }

        $isContinuous = $this->columnTypes[$column] == self::CONTINUOUS;

        $leaf = new DecisionTreeLeaf();
        $leaf->value = $mostFrequent;
        $leaf->giniIndex = $gini;
        $leaf->columnIndex = $column;
        $leaf->isContinuous = $isContinuous;
        $leaf->records = $records;

        // For a numeric column the split value has the form "<operator> <number>";
        // both parts are stored on the leaf separately for later access
        if ($isContinuous) {
            $parts = [];
            preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $leaf->value, $parts);
            $leaf->operator = $parts[1];
            $leaf->numericValue = (float) $parts[2];
        }

        $best = $leaf;
        $lowestGini = $gini;
    }

    return $best;
}
||
335 | |||
/**
 * Returns the column indices the tree may use for the next split decision.
 *
 * - Features set through setSelectedFeatures() take precedence and are
 *   returned as they are.
 * - Otherwise, if a number was given with setNumFeatures(), a random
 *   selection of at most that many features is returned in ascending order.
 * - If neither was configured, all features are returned.
 */
protected function getSelectedFeatures(): array
{
    if ($this->selectedFeatures !== []) {
        return $this->selectedFeatures;
    }

    $candidates = range(0, $this->featureCount - 1);
    if ($this->numUsableFeatures === 0) {
        return $candidates;
    }

    // Never ask for more columns than there are
    $limit = min($this->numUsableFeatures, $this->featureCount);

    shuffle($candidates);
    $chosen = array_slice($candidates, 0, $limit, false);
    sort($chosen);

    return $chosen;
}
||
371 | |||
/**
 * Converts the values of continuous columns into the discrete labels
 * "<= median" / "> median", using the column median of the given samples
 * as the threshold. Nominal columns are passed through unchanged.
 *
 * @param array $samples rows of feature values
 *
 * @return array rows again (re-indexed from 0), each row being a list of column values
 */
protected function preprocess(array $samples): array
{
    $columns = [];
    for ($i = 0; $i < $this->featureCount; ++$i) {
        $values = array_column($samples, $i);
        if ($this->columnTypes[$i] == self::CONTINUOUS) {
            $median = Mean::median($values);
            // Mapping instead of a by-reference loop avoids leaving a
            // dangling reference to the last element behind
            $values = array_map(
                static function ($value) use ($median): string {
                    return $value <= $median ? "<= {$median}" : "> {$median}";
                },
                $values
            );
        }

        $columns[] = $values;
    }

    // array_map(null, $single) returns the array unchanged instead of zipping it,
    // so a single column has to be wrapped into one-element rows by hand
    if (count($columns) === 1) {
        return array_map(
            static function ($value): array {
                return [$value];
            },
            $columns[0]
        );
    }

    // Zipping the columns with a null callback transposes the 2D array
    return array_map(null, ...$columns);
}
||
397 | |||
/**
 * Guesses whether a column holds a discrete set of values.
 *
 * - Any float value makes the column non-categorical.
 * - Otherwise any non-numeric value makes it categorical.
 * - Otherwise it is categorical when the number of distinct values is at
 *   most 20% of the number of values.
 */
protected static function isCategoricalColumn(array $columnValues): bool
{
    $hasNonNumeric = false;
    foreach ($columnValues as $value) {
        if (is_float($value)) {
            // A single float is enough to treat the column as continuous
            return false;
        }

        if (!is_numeric($value)) {
            $hasNonNumeric = true;
        }
    }

    if ($hasNonNumeric) {
        return true;
    }

    $distinctCount = count(array_count_values($columnValues));

    return $distinctCount <= count($columnValues) / 5;
}
||
421 | |||
/**
 * Stores a predefined list of column indices; when it is not empty,
 * getSelectedFeatures() returns exactly this list for every split.
 *
 * @param array $selectedFeatures column indices to consider
 */
protected function setSelectedFeatures(array $selectedFeatures): void
{
    // Stored as given; no validation is performed here
    $this->selectedFeatures = $selectedFeatures;
}
||
429 | |||
/**
 * Collects the internal (non-terminal) nodes of the subtree rooted at
 * $node that split on the given column. The node itself comes first,
 * followed by the matches of its left and then its right subtree.
 *
 * @param int              $column column index to look for
 * @param DecisionTreeLeaf $node   root of the subtree to inspect
 *
 * @return array matching nodes
 */
protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
{
    if ($node->isTerminal) {
        return [];
    }

    $matches = $node->columnIndex === $column ? [$node] : [];

    // Descend into whichever children exist, left before right
    $fromLeft = $node->leftLeaf ? $this->getSplitNodesByColumn($column, $node->leftLeaf) : [];
    $fromRight = $node->rightLeaf ? $this->getSplitNodesByColumn($column, $node->rightLeaf) : [];

    return array_merge($matches, $fromLeft, $fromRight);
}
||
459 | |||
/**
 * Walks the tree from the root, following the left child when the node's
 * criterion evaluates to true for the sample and the right child
 * otherwise, until a terminal node is reached.
 *
 * @return mixed class value of the terminal node reached, or the first
 *               known label when the walk ends on a missing child
 */
protected function predictSample(array $sample)
{
    $node = $this->tree;

    while (!$node->isTerminal) {
        $node = $node->evaluate($sample) ? $node->leftLeaf : $node->rightLeaf;

        if (!$node) {
            // The chosen branch has no child: fall back to the first label
            return $this->labels[0];
        }
    }

    return $node->classValue;
}
||
480 | } |
||
481 |
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using empty(...) or ! empty(...) instead.