Passed
Push — master ( 01bb82...39747e )
by Arkadiusz
03:07
created

DecisionTree::getHtml()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 4
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
eloc 2
nc 1
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Helper\Predictable;
9
use Phpml\Helper\Trainable;
10
use Phpml\Math\Statistic\Mean;
11
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
12
13
class DecisionTree implements Classifier
14
{
15
    use Trainable, Predictable;
16
17
    const CONTINUOUS = 1;
18
    const NOMINAL = 2;
19
20
    /**
21
     * @var array
22
     */
23
    protected $columnTypes;
24
25
    /**
26
     * @var array
27
     */
28
    private $labels = [];
29
30
    /**
31
     * @var int
32
     */
33
    private $featureCount = 0;
34
35
    /**
36
     * @var DecisionTreeLeaf|null
37
     */
38
    protected $tree = null;
39
40
    /**
41
     * @var int
42
     */
43
    protected $maxDepth;
44
45
    /**
46
     * @var int
47
     */
48
    public $actualDepth = 0;
49
50
    /**
51
     * @var int
52
     */
53
    private $numUsableFeatures = 0;
54
55
    /**
56
     * @var array
57
     */
58
    private $selectedFeatures;
59
60
    /**
61
     * @var array|null
62
     */
63
    private $featureImportances = null;
64
65
    /**
66
     *
67
     * @var array|null
68
     */
69
    private $columnNames = null;
70
71
    /**
     * Creates an untrained tree whose growth is capped at the given depth.
     *
     * @param int $maxDepth Depth at which getSplitLeaf() stops splitting and
     *                      marks the current node as terminal
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }
78
79
    /**
     * Appends the given samples/targets to the data accumulated so far and
     * rebuilds the whole tree from the accumulated data set.
     *
     * Cached feature importances are dropped, and the column names are
     * created, truncated or padded so that there is exactly one per feature.
     *
     * @param array $samples Rows of feature values; the feature count is read
     *                       from the first accumulated row
     * @param array $targets Class labels, indexed like the samples
     */
    public function train(array $samples, array $targets)
    {
        // $this->samples / $this->targets are not declared in this class
        // (presumably they come from the Trainable trait -- TODO confirm)
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        $lastRecordNo = count($this->samples) - 1;
        $this->tree = $this->getSplitLeaf(range(0, $lastRecordNo));

        // Importances depend on the tree, so a retrained tree invalidates
        // the cached values; getFeatureImportances() recomputes them lazily
        $this->featureImportances = null;

        // No names known yet: fall back to the numeric column indexes
        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);

            return;
        }

        // Names were supplied or computed earlier: keep them and only fix
        // up the length so that it matches the current feature count
        $knownNames = count($this->columnNames);
        if ($knownNames > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif ($knownNames < $this->featureCount) {
            $missingNames = range($knownNames, $this->featureCount - 1);
            $this->columnNames = array_merge($this->columnNames, $missingNames);
        }
    }
108
109
    /**
     * Classifies every column of the data set as NOMINAL or CONTINUOUS.
     *
     * The number of columns is taken from the first sample; the decision
     * for each column is delegated to isCategoricalColumn().
     *
     * @param array $samples Rows of feature values
     * @return array One self::NOMINAL / self::CONTINUOUS entry per column
     */
    public static function getColumnTypes(array $samples) : array
    {
        $columnTotal = count($samples[0]);
        $types = [];

        $column = 0;
        while ($column < $columnTotal) {
            $values = array_column($samples, $column);
            if (self::isCategoricalColumn($values)) {
                $types[] = self::NOMINAL;
            } else {
                $types[] = self::CONTINUOUS;
            }
            $column++;
        }

        return $types;
    }
125
126
    /**
     * Recursively builds the (sub)tree for the given records.
     *
     * The best split for the records is chosen first. The node becomes a
     * terminal leaf labelled with the most frequent target among its records
     * when all records are equal, when the maximum depth is reached, or when
     * only one class remains. Otherwise the records are partitioned by the
     * split and each non-empty side is built recursively one level deeper.
     *
     * @param array $records Indexes into the training samples/targets
     * @param int $depth Depth of the node being built (0 for the root)
     * @return DecisionTreeLeaf Root of the constructed (sub)tree
     */
    protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLeaf
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        // Traverse all records to see if all records belong to the same class,
        // otherwise group the records so that we can classify the leaf
        // in case maximum depth is reached
        $leftRecords = [];
        $rightRecords = [];
        $remainingTargets = [];
        $prevRecord = null;
        $allSame = true;

        foreach ($records as $recordNo) {
            // Comparing neighbours is enough to detect any difference
            $record = $this->samples[$recordNo];
            if ($prevRecord && $prevRecord != $record) {
                $allSame = false;
            }
            $prevRecord = $record;

            // According to the split criterion, this record will
            // belong to either the left or the right side in the next split
            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            // Count how often each target occurs among the records
            $target = $this->targets[$recordNo];
            if (! array_key_exists($target, $remainingTargets)) {
                $remainingTargets[$target] = 1;
            } else {
                $remainingTargets[$target]++;
            }
        }

        if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
            // isTerminal is a boolean flag, so store a real boolean
            $split->isTerminal = true;
            // Sorting by count (descending) puts the majority class first
            arsort($remainingTargets);
            $split->classValue = key($remainingTargets);

            return $split;
        }

        // A side without records gets no child node at all
        if (count($leftRecords) > 0) {
            $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
        }
        if (count($rightRecords) > 0) {
            $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
        }

        return $split;
    }
188
189
    /**
     * Finds the split with the lowest Gini index for the given records.
     *
     * For every candidate feature from getSelectedFeatures(), a most frequent
     * (preprocessed) value of that column is used as the split value. The
     * first feature reaching the lowest Gini index wins; later features only
     * replace it when they are strictly better.
     *
     * @param array $records Indexes into the training samples/targets
     * @return DecisionTreeLeaf Node describing the chosen split
     */
    protected function getBestSplit(array $records) : DecisionTreeLeaf
    {
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = array_intersect_key($this->samples, $recordKeys);
        // preprocess() re-indexes its rows, so restore the record numbers as keys
        $samples = array_combine($records, $this->preprocess($samples));

        $bestSplit = null;
        $bestGiniVal = 1;
        foreach ($this->getSelectedFeatures() as $column) {
            $colValues = [];
            foreach ($samples as $recordNo => $row) {
                $colValues[$recordNo] = $row[$column];
            }

            // The value with the highest occurrence count is the split candidate
            $occurrences = array_count_values($colValues);
            arsort($occurrences);
            $baseValue = key($occurrences);

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit !== null && !($bestGiniVal > $gini)) {
                // Not an improvement over the split found so far
                continue;
            }

            $isContinuous = $this->columnTypes[$column] == self::CONTINUOUS;

            $candidate = new DecisionTreeLeaf();
            $candidate->value = $baseValue;
            $candidate->giniIndex = $gini;
            $candidate->columnIndex = $column;
            $candidate->isContinuous = $isContinuous;
            $candidate->records = $records;

            // For a numeric column the split value is a string such as
            // "<= 4.2"; the operator and the number are stored separately
            // on the leaf as well for future access
            if ($isContinuous) {
                $matches = [];
                preg_match("/^([<>=]{1,2})\s*(.*)/", strval($candidate->value), $matches);
                $candidate->operator = $matches[1];
                $candidate->numericValue = floatval($matches[2]);
            }

            $bestSplit = $candidate;
            $bestGiniVal = $gini;
        }

        return $bestSplit;
    }
235
236
    /**
     * Returns the column indexes the tree may use when deciding a split.
     *
     * - Features set through setSelectedFeatures() take precedence and are
     *   returned as they are.
     * - Otherwise, if a number was given through setNumFeatures(), a random
     *   selection of at most that many features is returned in ascending order.
     * - Otherwise all features are returned.
     *
     * @return array Column indexes
     */
    protected function getSelectedFeatures() : array
    {
        if ($this->selectedFeatures) {
            return $this->selectedFeatures;
        }

        $candidates = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0) {
            return $candidates;
        }

        // Never ask for more features than there are
        $limit = $this->numUsableFeatures > $this->featureCount
            ? $this->featureCount
            : $this->numUsableFeatures;

        shuffle($candidates);
        $picked = array_slice($candidates, 0, $limit, false);
        sort($picked);

        return $picked;
    }
272
273
    /**
     * Computes the size-weighted Gini impurity of splitting a column into the
     * values identical (===) to $baseValue and all other values.
     *
     * @param mixed $baseValue Value that defines the first group
     * @param array $colValues Column values, keyed like $targets
     * @param array $targets Class label of each value
     * @return float Sum over both groups of (1 - sum of squared class shares)
     *               times the group size, divided by the number of values
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets) : float
    {
        // One [matching, non-matching] counter pair per known label
        $countMatrix = array_fill_keys($this->labels, [0, 0]);
        foreach ($colValues as $index => $value) {
            $side = ($value === $baseValue) ? 0 : 1;
            $countMatrix[$targets[$index]][$side]++;
        }

        $weightedImpurities = [0, 0];
        foreach ([0, 1] as $side) {
            $sideTotal = array_sum(array_column($countMatrix, $side));

            $squaredShares = 0;
            if ($sideTotal > 0) {
                foreach ($this->labels as $label) {
                    $squaredShares += pow($countMatrix[$label][$side] / floatval($sideTotal), 2);
                }
            }

            // An empty side has a total of zero and thus contributes nothing
            $weightedImpurities[$side] = (1 - $squaredShares) * $sideTotal;
        }

        return array_sum($weightedImpurities) / count($colValues);
    }
304
305
    /**
     * Converts continuous columns into two-valued string columns so that
     * every column can be handled as a discrete one when looking for a split.
     *
     * Each value of a CONTINUOUS column is replaced by "<= m" or "> m", where
     * m is the value returned by Mean::median() for that column (presumably
     * the column median). Other columns are passed through unchanged.
     *
     * @param array $samples Rows to convert; their keys are not preserved
     * @return array Converted rows, re-indexed from 0, one array per sample
     */
    protected function preprocess(array $samples) : array
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; $i++) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] == self::CONTINUOUS) {
                $median = Mean::median($values);
                // array_map instead of a by-reference foreach, so that no
                // reference to the last element outlives the conversion
                $values = array_map(function ($value) use ($median) {
                    return $value <= $median ? "<= $median" : "> $median";
                }, $values);
            }
            $columns[] = $values;
        }

        // array_map(null, $onlyColumn) returns the column unchanged instead of
        // zipping it into one-element rows, so a single-feature data set has
        // to be transposed by hand
        if (count($columns) === 1) {
            return array_map(function ($value) {
                return [$value];
            }, $columns[0]);
        }

        // With a null callback array_map zips its arguments, which
        // transposes the list of columns into a list of rows
        return array_map(null, ...$columns);
    }
332
333
    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * The rules are applied in this order:
     * 1- Any float value makes the column non-categorical
     * 2- Any non-numeric value makes the column categorical
     * 3- Otherwise the column is categorical when the number of distinct
     *    values is at most one fifth (20%) of the number of values
     *
     * @param array $columnValues All values of one column
     * @return bool True when the column is considered categorical
     */
    protected static function isCategoricalColumn(array $columnValues) : bool
    {
        foreach ($columnValues as $value) {
            if (is_float($value)) {
                return false;
            }
        }

        foreach ($columnValues as $value) {
            if (! is_numeric($value)) {
                return true;
            }
        }

        $distinctTotal = count(array_count_values($columnValues));

        return $distinctTotal <= count($columnValues) / 5;
    }
359
360
    /**
     * Sets the maximum number of columns that are randomly picked as
     * candidates each time a split is decided.
     *
     * With 0 (the default) no random selection takes place; see
     * getSelectedFeatures() for how this value is applied.
     *
     * @param int $numFeatures Zero or a positive upper bound
     * @return $this
     * @throws InvalidArgumentException When $numFeatures is negative
     */
    public function setNumFeatures(int $numFeatures)
    {
        $isNegative = $numFeatures < 0;
        if ($isNegative) {
            throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
        }

        $this->numUsableFeatures = $numFeatures;

        return $this;
    }
381
382
    /**
     * Fixes the set of columns to consider when deciding a split.
     *
     * A non-empty list is returned as is by getSelectedFeatures().
     *
     * @param array $selectedFeatures Column indexes
     */
    protected function setSelectedFeatures(array $selectedFeatures)
    {
        $this->selectedFeatures = $selectedFeatures;
    }
391
392
    /**
     * Sets the names used as keys by getFeatureImportances() and handed to
     * the tree's getHTML() by getHtml().
     *
     * @param array $names One name per column
     * @return $this
     * @throws InvalidArgumentException When a feature count is already known
     *                                  (it is set by train()) and the number
     *                                  of names differs from it
     */
    public function setColumnNames(array $names)
    {
        $countIsKnown = $this->featureCount !== 0;
        if ($countIsKnown && count($names) !== $this->featureCount) {
            $message = sprintf('Length of the given array should be equal to feature count %s', $this->featureCount);

            throw new InvalidArgumentException($message);
        }

        $this->columnNames = $names;

        return $this;
    }
410
411
    /**
     * Renders the trained tree by delegating to the root node's getHTML(),
     * passing the current column names along.
     *
     * @return string Whatever the root node's getHTML() produces
     */
    public function getHtml()
    {
        $root = $this->tree;

        return $root->getHTML($this->columnNames);
    }
418
419
    /**
     * Returns an importance value for every column, keyed by column name.
     *
     * The importance of a column is the sum of getNodeImpurityDecrease()
     * over all internal nodes that split on that column. When the total is
     * positive the values are normalized so that they add up to 1 and are
     * sorted in descending order. The result is cached until the next call
     * of train().
     *
     * @return array Importance per column name
     */
    public function getFeatureImportances()
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);
        $this->featureImportances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $decrease = 0;
            foreach ($this->getSplitNodesByColumn($column, $this->tree) as $node) {
                $decrease += $node->getNodeImpurityDecrease($sampleCount);
            }

            $this->featureImportances[$columnName] = $decrease;
        }

        // Normalize & sort the importances
        $total = array_sum($this->featureImportances);
        if ($total > 0) {
            // array_map keeps the keys of a single input array and, unlike a
            // by-reference foreach, leaves no dangling reference behind
            $this->featureImportances = array_map(function ($importance) use ($total) {
                return $importance / $total;
            }, $this->featureImportances);
            arsort($this->featureImportances);
        }

        return $this->featureImportances;
    }
456
457
    /**
     * Collects the internal (non-terminal) nodes below and including $node
     * that use the given column as their split criterion.
     *
     * The node itself comes first, followed by the matches of its left and
     * then of its right subtree.
     *
     * @param int $column Column index to look for
     * @param DecisionTreeLeaf $node Root of the subtree to search
     * @return array Matching DecisionTreeLeaf nodes
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) : array
    {
        // Terminal nodes do not split on anything and are not descended into
        if ($node->isTerminal) {
            return [];
        }

        $matches = ($node->columnIndex === $column) ? [$node] : [];

        $leftMatches = [];
        if ($node->leftLeaf) {
            $leftMatches = $this->getSplitNodesByColumn($column, $node->leftLeaf);
        }

        $rightMatches = [];
        if ($node->rightLeaf) {
            $rightMatches = $this->getSplitNodesByColumn($column, $node->rightLeaf);
        }

        return array_merge($matches, $leftMatches, $rightMatches);
    }
488
489
    /**
     * Walks the tree from the root for the given sample and returns the
     * class of the terminal node that is reached.
     *
     * At each internal node a truthy evaluate() result leads to the left
     * child, anything else to the right child. If the selected child does not
     * exist, the first known label is returned as a fallback.
     *
     * @param array $sample Feature values of one sample
     * @return mixed Predicted class value
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        while (! $node->isTerminal) {
            $node = $node->evaluate($sample) ? $node->leftLeaf : $node->rightLeaf;

            // The traversal ran into a side for which no child was built
            if (! $node) {
                return $this->labels[0];
            }
        }

        return $node->classValue;
    }
509
}
510