Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Classification/DecisionTree.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\Helper\Predictable;
10
use Phpml\Helper\Trainable;
11
use Phpml\Math\Statistic\Mean;
12
13
class DecisionTree implements Classifier
14
{
15
    use Trainable, Predictable;
16
17
    /**
     * Column type assigned to columns that are not detected as categorical;
     * their values are discretized around the median before splitting.
     */
    public const CONTINUOUS = 1;

    /**
     * Column type assigned to columns detected as categorical.
     */
    public const NOMINAL = 2;

    /**
     * Deepest level reached so far while building the tree.
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * Type of each column (self::CONTINUOUS or self::NOMINAL), by column position.
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Root node of the tree; null until train() has been called.
     *
     * @var DecisionTreeLeaf|null
     */
    protected $tree = null;

    /**
     * Depth at which nodes are forced to become terminal.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * Distinct target values found in the training data.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Number of columns, taken from the first training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Maximum number of randomly picked columns considered per split;
     * 0 means that all columns are considered.
     *
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * Column indices explicitly chosen for splitting; empty when none were chosen.
     *
     * @var array
     */
    private $selectedFeatures = [];

    /**
     * Cached feature importances; null means they have to be (re)computed.
     *
     * @var array|null
     */
    private $featureImportances;

    /**
     * Display name of each column, by column position.
     *
     * @var array
     */
    private $columnNames = [];
70
71
    /**
     * @param int $maxDepth Depth at which nodes are forced to become terminal
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }
75
76
    /**
     * Appends the given samples/targets to the training set and rebuilds the
     * tree from all of the accumulated data.
     *
     * @param array $samples Rows of feature values; the feature count is taken from the first row
     * @param array $targets Class label of each sample, in the same order as $samples
     *
     * @throws InvalidArgumentException if there is no sample to train on
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // Everything below reads the first sample, so fail early and clearly
        // instead of tripping over an undefined offset
        if ($this->samples === []) {
            throw new InvalidArgumentException('Cannot train a decision tree without samples');
        }

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        // The tree is rebuilt from scratch, so the depth reached by a
        // previous training run must not leak into this one
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Feature importances depend on the tree; drop the cached values so
        // that they are recomputed on demand
        $this->featureImportances = null;

        // Keep any column names given beforehand and only reconcile their
        // number with the feature count: generate, truncate or pad with indices
        $nameCount = count($this->columnNames);
        if ($nameCount === 0) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif ($nameCount > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif ($nameCount < $this->featureCount) {
            $this->columnNames = array_merge(
                $this->columnNames,
                range($nameCount, $this->featureCount - 1)
            );
        }
    }
103
104
    /**
     * Detects the type of every column of the given samples.
     *
     * The number of columns is taken from the first sample.
     *
     * @return int[] self::NOMINAL or self::CONTINUOUS for each column position
     */
    public static function getColumnTypes(array $samples): array
    {
        $columnTypes = [];
        $columnCount = count($samples[0]);
        for ($column = 0; $column < $columnCount; ++$column) {
            if (self::isCategoricalColumn(array_column($samples, $column))) {
                $columnTypes[] = self::NOMINAL;
            } else {
                $columnTypes[] = self::CONTINUOUS;
            }
        }

        return $columnTypes;
    }
116
117
    /**
     * Computes the weighted Gini index of splitting a column into the values
     * identical (===) to $baseValue and all the other values.
     *
     * Labels are counted against the labels collected by train(), and the
     * result is divided by the number of column values, so $colValues is
     * expected to be non-empty.
     *
     * @param mixed $baseValue Value that selects the first side of the split
     * @param array $colValues Column values, keyed like $targets
     * @param array $targets   Class label for each key of $colValues
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        // One [matching, non-matching] counter pair per known label
        $countMatrix = array_fill_keys($this->labels, [0, 0]);

        foreach ($colValues as $index => $value) {
            $side = $value === $baseValue ? 0 : 1;
            ++$countMatrix[$targets[$index]][$side];
        }

        $weightedImpurity = 0;
        foreach ([0, 1] as $side) {
            $sideTotal = array_sum(array_column($countMatrix, $side));
            $squaredShares = 0;
            if ($sideTotal > 0) {
                foreach ($this->labels as $label) {
                    $squaredShares += pow($countMatrix[$label][$side] / (float) $sideTotal, 2);
                }
            }

            // Impurity of this side, weighted by the number of values on it
            $weightedImpurity += (1 - $squaredShares) * $sideTotal;
        }

        return $weightedImpurity / count($colValues);
    }
148
149
    /**
     * Sets the maximum number of columns that are randomly selected as
     * candidates for each split operation.
     *
     * With 0 (the default) all columns are candidates.
     *
     * @return $this
     *
     * @throws InvalidArgumentException if the given number is negative
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures >= 0) {
            $this->numUsableFeatures = $numFeatures;

            return $this;
        }

        throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
    }
170
171
    /**
     * Sets the names used to represent the columns, e.g. in the HTML output
     * and as the keys of the feature importances.
     *
     * Once the feature count is known (non-zero), the number of names has to
     * match it.
     *
     * @return $this
     *
     * @throws InvalidArgumentException if the number of names differs from the known feature count
     */
    public function setColumnNames(array $names)
    {
        $expectedCount = $this->featureCount;
        if ($expectedCount !== 0 && count($names) !== $expectedCount) {
            $message = sprintf('Length of the given array should be equal to feature count %s', $expectedCount);

            throw new InvalidArgumentException($message);
        }

        $this->columnNames = $names;

        return $this;
    }
189
190
    /**
     * Returns the HTML representation produced by the root node, passing it
     * the column names.
     */
    public function getHtml(): string
    {
        $root = $this->tree;

        return $root->getHTML($this->columnNames);
    }
194
195
    /**
     * Returns an importance value for each column, keyed by column name.
     *
     * The importance of a column is the sum of the impurity decreases of the
     * nodes that split on it. When the total is positive the values are
     * normalized so that they add up to 1 and are sorted in descending order.
     * The result is cached until the tree is trained again.
     */
    public function getFeatureImportances(): array
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        // Work on a local array and publish it only when complete, so that a
        // failure midway cannot leave a half-filled cache behind
        $sampleCount = count($this->samples);
        $importances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $importance = 0;
            foreach ($this->getSplitNodesByColumn($column, $this->tree) as $node) {
                $importance += $node->getNodeImpurityDecrease($sampleCount);
            }

            $importances[$columnName] = $importance;
        }

        // Normalize & sort the importances (keyed writes instead of a
        // by-reference loop, which would leave a dangling reference)
        $total = array_sum($importances);
        if ($total > 0) {
            foreach ($importances as $columnName => $importance) {
                $importances[$columnName] = $importance / $total;
            }

            arsort($importances);
        }

        $this->featureImportances = $importances;

        return $this->featureImportances;
    }
231
232
    /**
     * Builds the node for the given records and, unless a stop condition is
     * met, recursively builds its left and right children.
     *
     * A node becomes terminal when all of its records are equal, when the
     * maximum depth is reached or when only one class remains; it is then
     * labelled with the most frequent remaining class.
     *
     * @param array $records Indices into the training samples/targets
     * @param int   $depth   Level of the node being built (the root is 0)
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        // Traverse all records to see if all records are the same,
        // otherwise group the records so that we can classify the leaf
        // in case maximum depth is reached
        $leftRecords = [];
        $rightRecords = [];
        $remainingTargets = [];
        $prevRecord = null;
        $allSame = true;

        foreach ($records as $recordNo) {
            $record = $this->samples[$recordNo];

            // Explicit null check: only the first iteration has nothing to
            // compare with. The loose (!=) comparison matches rows value by value.
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }

            $prevRecord = $record;

            // According to the split criterion, this record will
            // belong to either the left or the right side in the next split
            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            // Count the records per remaining target
            $target = $this->targets[$recordNo];
            if (!array_key_exists($target, $remainingTargets)) {
                $remainingTargets[$target] = 1;
            } else {
                ++$remainingTargets[$target];
            }
        }

        if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
            // NOTE(review): 1 is used as a truthy flag here; confirm the
            // declared type of DecisionTreeLeaf::$isTerminal
            $split->isTerminal = 1;

            // Most frequent target first, so key() yields the majority class
            arsort($remainingTargets);
            $split->classValue = key($remainingTargets);

            return $split;
        }

        // A side that received no record gets no child node
        if ($leftRecords !== []) {
            $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
        }

        if ($rightRecords !== []) {
            $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
        }

        return $split;
    }
291
292
    /**
     * Finds the column/value pair with the lowest Gini index for the given
     * records and returns it as a new node without children.
     *
     * For every candidate column the most frequent preprocessed value is
     * used as the split value.
     *
     * @param array $records Indices into the training samples/targets, in ascending order
     *
     * @throws InvalidArgumentException if there is no feature to split on
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = array_intersect_key($this->samples, $recordKeys);

        // preprocess() returns re-indexed rows; they are paired with the record
        // numbers by position, which relies on $records being in ascending order
        $samples = array_combine($records, $this->preprocess($samples));

        $bestGiniVal = 1;
        $bestSplit = null;
        foreach ($this->getSelectedFeatures() as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }

            // The most frequent value of the column is the split candidate
            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $isContinuous = $this->columnTypes[$i] == self::CONTINUOUS;

                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->isContinuous = $isContinuous;
                $split->records = $records;

                // If a numeric column is to be selected, then
                // the original numeric value and the selected operator
                // will also be saved into the leaf for future access.
                // preprocess() encodes such values as "<= x" or "> x".
                if ($isContinuous) {
                    $matches = [];
                    preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                    $split->operator = $matches[1];
                    $split->numericValue = (float) $matches[2];
                }

                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        // Without any candidate feature there is nothing to return; report it
        // explicitly rather than violating the declared return type
        if ($bestSplit === null) {
            throw new InvalidArgumentException('No feature is available to split on');
        }

        return $bestSplit;
    }
335
336
    /**
     * Returns the column indices available to the tree when deciding a split.
     *
     * - If features were chosen with setSelectedFeatures(), exactly those
     *   are returned.
     * - Otherwise, if a number was given with setNumFeatures(), a sorted
     *   random selection of at most that many columns is returned.
     * - Otherwise all columns are returned.
     */
    protected function getSelectedFeatures(): array
    {
        // Emptiness is compared explicitly instead of relying on the
        // implicit array-to-boolean conversion
        if ($this->selectedFeatures !== []) {
            return $this->selectedFeatures;
        }

        $allFeatures = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0) {
            return $allFeatures;
        }

        // Never ask for more columns than the data has
        $numFeatures = min($this->numUsableFeatures, $this->featureCount);

        shuffle($allFeatures);
        $selectedFeatures = array_slice($allFeatures, 0, $numFeatures);
        sort($selectedFeatures);

        return $selectedFeatures;
    }
371
372
    /**
     * Converts the values of continuous columns into the discrete values
     * "<= median" and "> median", using the median of the given samples as
     * the threshold. Nominal columns are passed through unchanged.
     *
     * @param array $samples Rows of feature values; their keys are not preserved
     *
     * @return array Re-indexed rows, each holding one value per column position
     */
    protected function preprocess(array $samples): array
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] == self::CONTINUOUS) {
                $median = Mean::median($values);

                // Keyed writes instead of a by-reference loop, which would
                // leave a dangling reference behind
                foreach ($values as $key => $value) {
                    $values[$key] = $value <= $median ? "<= {$median}" : "> {$median}";
                }
            }

            $columns[] = $values;
        }

        // Transpose the columns back into rows. This is done explicitly because
        // array_map(null, ...$columns) returns the column itself, not rows,
        // when there is only a single column.
        $rows = [];
        foreach ($columns as $columnNo => $values) {
            foreach ($values as $rowNo => $value) {
                $rows[$rowNo][$columnNo] = $value;
            }
        }

        return $rows;
    }
397
398
    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * - Any float value makes the column non-categorical.
     * - Otherwise any non-numeric value makes it categorical.
     * - Otherwise it is categorical when the number of distinct values is at
     *   most 20% of the number of values.
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        $total = count($columnValues);

        $numericCount = 0;
        foreach ($columnValues as $value) {
            // A single float is enough to rule the column out
            if (is_float($value)) {
                return false;
            }

            if (is_numeric($value)) {
                ++$numericCount;
            }
        }

        if ($numericCount !== $total) {
            return true;
        }

        $distinctCount = count(array_count_values($columnValues));

        return $distinctCount <= $total / 5;
    }
421
422
    /**
     * Sets the column indices that are considered, instead of all or
     * randomly picked columns, when deciding which column to split on.
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $this->selectedFeatures = $selectedFeatures;
    }
429
430
    /**
     * Collects the non-terminal nodes of the subtree rooted at the given
     * node that use the given column as their split criterion.
     *
     * @return DecisionTreeLeaf[] The node itself first (if it matches), then
     *                            the matches of the left and the right subtree
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        if ($node->isTerminal) {
            return [];
        }

        $ownMatch = $node->columnIndex === $column ? [$node] : [];

        $leftMatches = $node->leftLeaf
            ? $this->getSplitNodesByColumn($column, $node->leftLeaf)
            : [];

        $rightMatches = $node->rightLeaf
            ? $this->getSplitNodesByColumn($column, $node->rightLeaf)
            : [];

        return array_merge($ownMatch, $leftMatches, $rightMatches);
    }
459
460
    /**
     * Walks the tree from the root, going left when a node's evaluate()
     * accepts the sample and right otherwise, and returns the class value of
     * the terminal node that is reached.
     *
     * If the walk runs into a missing child, the first known label is
     * returned instead.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $current = $this->tree;
        while (!$current->isTerminal) {
            $current = $current->evaluate($sample) ? $current->leftLeaf : $current->rightLeaf;

            // Fell off the tree: there is no child on the chosen side
            if (!$current) {
                return $this->labels[0];
            }
        }

        return $current->classValue;
    }
480
}
481