Passed
Push — master ( 47cdff...ed5fc8 )
by Arkadiusz
03:38
created

src/Phpml/Classification/DecisionTree.php (7 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Helper\Predictable;
9
use Phpml\Helper\Trainable;
10
use Phpml\Math\Statistic\Mean;
11
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
12
13
class DecisionTree implements Classifier
{
    use Trainable, Predictable;

    const CONTINUOUS = 1;
    const NOMINAL = 2;

    /**
     * Type of each feature column (self::CONTINUOUS or self::NOMINAL),
     * indexed by column position.
     *
     * @var array
     */
    protected $columnTypes;

    /**
     * Distinct target labels seen during training.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Number of features (columns) per sample, taken from the first sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Root node of the trained tree; null until train() has been called.
     *
     * @var DecisionTreeLeaf|null
     */
    protected $tree = null;

    /**
     * Depth at which nodes are forced to become terminal.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * Deepest level reached while building the most recently trained tree.
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * Maximum number of randomly chosen columns considered per split
     * (0 means "use all columns"), see setNumFeatures().
     *
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * Columns explicitly chosen via setSelectedFeatures(); null/empty if none.
     *
     * @var array|null
     */
    private $selectedFeatures;

    /**
     * Cached result of getFeatureImportances(); null when it must be recomputed.
     *
     * @var array|null
     */
    private $featureImportances = null;

    /**
     * Display names of the columns; null until given or computed in train().
     *
     * @var array|null
     */
    private $columnNames = null;

    /**
     * @param int $maxDepth depth at which nodes are made terminal
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Appends the given samples and targets to any previously trained data
     * and rebuilds the whole tree from the accumulated data set.
     *
     * @param array $samples
     * @param array $targets
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        // The tree is rebuilt from scratch, so the depth reached by a
        // previously trained tree must not leak into the new statistic
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Each time the tree is trained, feature importances are reset so that
        // they will be computed again based on the new data
        $this->featureImportances = null;

        // If column names were given or computed before, there is no need to
        // re-initialize them (which would discard user-given names); they are
        // only trimmed or padded to match the current feature count
        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif (count($this->columnNames) > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif (count($this->columnNames) < $this->featureCount) {
            $this->columnNames = array_merge(
                $this->columnNames,
                range(count($this->columnNames), $this->featureCount - 1)
            );
        }
    }

    /**
     * Classifies every column of the given samples as nominal or continuous.
     *
     * @param array $samples
     *
     * @return array list of self::NOMINAL / self::CONTINUOUS, one per column
     */
    public static function getColumnTypes(array $samples) : array
    {
        $types = [];
        $featureCount = count($samples[0]);
        for ($i = 0; $i < $featureCount; ++$i) {
            $values = array_column($samples, $i);
            $isCategorical = self::isCategoricalColumn($values);
            $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS;
        }

        return $types;
    }

    /**
     * Recursively builds the (sub)tree for the given records.
     *
     * The node becomes terminal (and is labelled with the most frequent
     * target) when all records are identical, when they all share one target,
     * or when the maximum depth is reached; otherwise its records are split
     * into left/right children according to the best split criterion.
     *
     * @param array $records indices into $this->samples / $this->targets
     * @param int   $depth   depth of the node being built (root is 0)
     *
     * @return DecisionTreeLeaf
     */
    protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLeaf
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        // Traverse all records to see if all records belong to the same class,
        // otherwise group the records so that we can classify the leaf
        // in case maximum depth is reached
        $leftRecords = [];
        $rightRecords = [];
        $remainingTargets = [];
        $prevRecord = null;
        $allSame = true;

        foreach ($records as $recordNo) {
            // Check if the previous record is the same as the current one
            $record = $this->samples[$recordNo];
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }
            $prevRecord = $record;

            // According to the split criterion, this record will
            // belong to either the left or the right side in the next split
            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            // Count occurrences of each remaining target
            $target = $this->targets[$recordNo];
            if (!array_key_exists($target, $remainingTargets)) {
                $remainingTargets[$target] = 1;
            } else {
                ++$remainingTargets[$target];
            }
        }

        if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
            $split->isTerminal = true;
            arsort($remainingTargets);
            $split->classValue = key($remainingTargets);
        } else {
            if (!empty($leftRecords)) {
                $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
            }
            if (!empty($rightRecords)) {
                $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
            }
        }

        return $split;
    }

    /**
     * Chooses, among the currently usable features, the split with the lowest
     * Gini index for the given records. Continuous columns are first
     * discretized around their median by preprocess(); for each column the
     * most frequent (discretized) value is used as the candidate split value.
     *
     * @param array $records indices into $this->samples / $this->targets
     *
     * @return DecisionTreeLeaf
     */
    protected function getBestSplit(array $records) : DecisionTreeLeaf
    {
        $targets = array_intersect_key($this->targets, array_flip($records));
        $samples = array_intersect_key($this->samples, array_flip($records));
        // preprocess() re-indexes rows from 0, so restore the record indices
        $samples = array_combine($records, $this->preprocess($samples));
        $bestGiniVal = 1;
        $bestSplit = null;
        $features = $this->getSelectedFeatures();
        foreach ($features as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }
            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->isContinuous = $this->columnTypes[$i] == self::CONTINUOUS;
                $split->records = $records;

                // If a numeric column is to be selected, then
                // the original numeric value and the selected operator
                // will also be saved into the leaf for future access
                if ($this->columnTypes[$i] == self::CONTINUOUS) {
                    $matches = [];
                    preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches);
                    $split->operator = $matches[1];
                    $split->numericValue = floatval($matches[2]);
                }

                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Returns available features/columns to the tree for the decision making
     * process. <br>
     *
     * If a number is given with setNumFeatures() method, then a random selection
     * of features up to this number is returned. <br>
     *
     * If some features are manually selected by use of setSelectedFeatures(),
     * then only these features are returned <br>
     *
     * If any of above methods were not called beforehand, then all features
     * are returned by default.
     *
     * @return array
     */
    protected function getSelectedFeatures() : array
    {
        $allFeatures = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0 && empty($this->selectedFeatures)) {
            return $allFeatures;
        }

        if (!empty($this->selectedFeatures)) {
            return $this->selectedFeatures;
        }

        $numFeatures = $this->numUsableFeatures;
        if ($numFeatures > $this->featureCount) {
            $numFeatures = $this->featureCount;
        }
        shuffle($allFeatures);
        $selectedFeatures = array_slice($allFeatures, 0, $numFeatures, false);
        sort($selectedFeatures);

        return $selectedFeatures;
    }

    /**
     * Computes the weighted Gini impurity of splitting the given column into
     * "equal to $baseValue" and "different from $baseValue" partitions.
     *
     * Values are compared as strings: the base value normally comes from the
     * keys of array_count_values(), which converts numeric-string keys to
     * integers, so a strict comparison would never match numeric-string
     * column values and every row would end up in the second partition.
     *
     * @param mixed $baseValue
     * @param array $colValues column values keyed by record index
     * @param array $targets   targets keyed by the same record indices
     *
     * @return float
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets) : float
    {
        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }

        $baseKey = (string) $baseValue;
        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            $rowIndex = (string) $value === $baseKey ? 0 : 1;
            ++$countMatrix[$label][$rowIndex];
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; ++$i) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += pow($countMatrix[$label][$i] / floatval($sum), 2);
                }
            }

            // Weight each partition's impurity by its size
            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / count($colValues);
    }

    /**
     * Returns the given samples as a list of rows (re-indexed from 0) in which
     * every continuous column value is replaced by the string
     * "<= median" or "> median" of that column.
     *
     * @param array $samples
     *
     * @return array
     */
    protected function preprocess(array $samples) : array
    {
        // Detect and convert continuous data column values into
        // discrete values by using the median as a threshold value
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] == self::CONTINUOUS) {
                $median = Mean::median($values);
                foreach ($values as &$value) {
                    if ($value <= $median) {
                        $value = "<= $median";
                    } else {
                        $value = "> $median";
                    }
                }
                // Break the reference so later writes cannot alter this column
                unset($value);
            }
            $columns[] = $values;
        }

        // Transpose columns back into rows. array_map(null, ...$columns) is not
        // used because with a single column it returns that column unchanged
        // instead of a list of one-element rows.
        $rows = [];
        foreach ($columns as $columnIndex => $column) {
            foreach ($column as $rowIndex => $cell) {
                $rows[$rowIndex][$columnIndex] = $cell;
            }
        }

        return $rows;
    }

    /**
     * Heuristically decides whether a column holds a discrete set of values:
     * - any float value => not categorical (continuous)
     * - any non-numeric value => categorical
     * - otherwise categorical when the number of distinct values is at most
     *   20% of all values in the column
     *
     * @param array $columnValues
     *
     * @return bool
     */
    protected static function isCategoricalColumn(array $columnValues) : bool
    {
        $count = count($columnValues);

        $numericValues = array_filter($columnValues, 'is_numeric');
        $floatValues = array_filter($columnValues, 'is_float');
        if (!empty($floatValues)) {
            return false;
        }

        if (count($numericValues) !== $count) {
            return true;
        }

        $distinctValues = array_count_values($columnValues);

        return count($distinctValues) <= $count / 5;
    }

    /**
     * This method is used to set number of columns to be used
     * when deciding a split at an internal node of the tree.  <br>
     * If the value is given 0, then all features are used (default behaviour),
     * otherwise the given value will be used as a maximum for number of columns
     * randomly selected for each split operation.
     *
     * @param int $numFeatures
     *
     * @return $this
     *
     * @throws InvalidArgumentException
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures < 0) {
            throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
        }

        $this->numUsableFeatures = $numFeatures;

        return $this;
    }

    /**
     * Used to set predefined features to consider while deciding which column to use for a split
     *
     * @param array $selectedFeatures
     */
    protected function setSelectedFeatures(array $selectedFeatures)
    {
        $this->selectedFeatures = $selectedFeatures;
    }

    /**
     * A string array to represent columns. Useful when HTML output or
     * column importances are desired to be inspected.
     *
     * @param array $names
     *
     * @return $this
     *
     * @throws InvalidArgumentException when the tree is already trained and
     *                                  the count differs from the feature count
     */
    public function setColumnNames(array $names)
    {
        if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
            throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
        }

        $this->columnNames = $names;

        return $this;
    }

    /**
     * Renders the trained tree as HTML using the configured column names.
     * Requires train() to have been called first (the tree is null before).
     *
     * @return string
     */
    public function getHtml()
    {
        return $this->tree->getHTML($this->columnNames);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset. The importance values are
     * normalized and their total makes 1.<br/>
     *
     * The result is keyed by column name, sorted by descending importance
     * (when the total is positive) and cached until the next train() call.
     *
     * @return array
     */
    public function getFeatureImportances()
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);
        $this->featureImportances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $nodes = $this->getSplitNodesByColumn($column, $this->tree);

            $importance = 0;
            foreach ($nodes as $node) {
                $importance += $node->getNodeImpurityDecrease($sampleCount);
            }

            $this->featureImportances[$columnName] = $importance;
        }

        // Normalize & sort the importances
        $total = array_sum($this->featureImportances);
        if ($total > 0) {
            foreach ($this->featureImportances as &$importance) {
                $importance /= $total;
            }
            // Break the reference to the last element before sorting/returning
            unset($importance);
            arsort($this->featureImportances);
        }

        return $this->featureImportances;
    }

    /**
     * Collects and returns an array of internal nodes that use the given
     * column as a split criterion
     *
     * @param int              $column
     * @param DecisionTreeLeaf $node
     *
     * @return array
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) : array
    {
        // Terminal nodes do not split on any column
        if ($node->isTerminal) {
            return [];
        }

        $nodes = [];
        if ($node->columnIndex === $column) {
            $nodes[] = $node;
        }

        $lNodes = [];
        $rNodes = [];
        if ($node->leftLeaf) {
            $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
        }

        if ($node->rightLeaf) {
            $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
        }

        return array_merge($nodes, $lNodes, $rNodes);
    }

    /**
     * Walks the tree from the root, going left when the node's criterion
     * evaluates true and right otherwise, until a terminal node is reached.
     * Falls back to the first known label if the walk ends on a missing child.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        do {
            if ($node->isTerminal) {
                break;
            }

            if ($node->evaluate($sample)) {
                $node = $node->leftLeaf;
            } else {
                $node = $node->rightLeaf;
            }
        } while ($node);

        return $node ? $node->classValue : $this->labels[0];
    }
}
531