DecisionTree::getGiniIndex()   B
last analyzed

Complexity

Conditions 7
Paths 18

Size

Total Lines 27
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 16
dl 0
loc 27
rs 8.8333
c 0
b 0
f 0
cc 7
nc 18
nop 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\Helper\Predictable;
10
use Phpml\Helper\Trainable;
11
use Phpml\Math\Statistic\Mean;
12
13
class DecisionTree implements Classifier
14
{
15
    use Trainable;
16
    use Predictable;
17
18
    public const CONTINUOUS = 1;
19
20
    public const NOMINAL = 2;
21
22
    /**
23
     * @var int
24
     */
25
    public $actualDepth = 0;
26
27
    /**
28
     * @var array
29
     */
30
    protected $columnTypes = [];
31
32
    /**
33
     * @var DecisionTreeLeaf
34
     */
35
    protected $tree;
36
37
    /**
38
     * @var int
39
     */
40
    protected $maxDepth;
41
42
    /**
43
     * @var array
44
     */
45
    private $labels = [];
46
47
    /**
48
     * @var int
49
     */
50
    private $featureCount = 0;
51
52
    /**
53
     * @var int
54
     */
55
    private $numUsableFeatures = 0;
56
57
    /**
58
     * @var array
59
     */
60
    private $selectedFeatures = [];
61
62
    /**
63
     * @var array|null
64
     */
65
    private $featureImportances;
66
67
    /**
68
     * @var array
69
     */
70
    private $columnNames = [];
71
72
    /**
     * Creates an untrained tree.
     *
     * @param int $maxDepth depth at which getSplitLeaf() stops splitting and
     *                      marks the current node as terminal
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }
76
77
    /**
     * Appends the given samples/targets to the stored training data and
     * rebuilds the whole tree from the accumulated data set.
     *
     * @param array $samples rows of feature values; the first stored row defines the feature count
     * @param array $targets labels belonging to the samples
     *
     * @throws InvalidArgumentException when there is no sample to train on
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // The feature count is read from the first row, so without one there
        // is nothing to build; fail with a clear message instead of an
        // undefined-offset / count() type error further down.
        if (!isset($this->samples[0])) {
            throw new InvalidArgumentException('Cannot train a decision tree without samples');
        }

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Each time the tree is trained, feature importances are reset so that
        // they are computed again from the new data on the next request
        $this->featureImportances = null;

        // Column names given or computed earlier are kept; they are only
        // created, truncated or padded with numeric names so that their
        // number matches the feature count
        $knownNames = count($this->columnNames);
        if ($knownNames === 0) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif ($knownNames > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif ($knownNames < $this->featureCount) {
            $this->columnNames = array_merge(
                $this->columnNames,
                range($knownNames, $this->featureCount - 1)
            );
        }
    }
104
105
    /**
     * Classifies every column of the data set as NOMINAL or CONTINUOUS.
     *
     * The column count is read from the first row; each column is extracted
     * with array_column() and judged by isCategoricalColumn().
     *
     * @return int[] one self::NOMINAL / self::CONTINUOUS entry per column, in column order
     */
    public static function getColumnTypes(array $samples): array
    {
        $columnTypes = [];
        $width = count($samples[0]);
        $column = 0;
        while ($column < $width) {
            if (self::isCategoricalColumn(array_column($samples, $column))) {
                $columnTypes[] = self::NOMINAL;
            } else {
                $columnTypes[] = self::CONTINUOUS;
            }

            ++$column;
        }

        return $columnTypes;
    }
117
118
    /**
     * Computes the size-weighted Gini impurity of splitting a column into the
     * values identical to $baseValue and all remaining values.
     *
     * @param mixed $baseValue value compared strictly (===) against every column value
     * @param array $colValues column values; their keys are used to look up $targets
     * @param array $targets   labels, addressed with the keys of $colValues
     *
     * @return float for both partitions, (1 - sum of squared label shares) * partition size,
     *               summed and divided by count($colValues); 0.0 when $colValues is empty
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        // An empty column has nothing to split; returning early also avoids
        // the division by zero at the end of the method.
        $valueCount = count($colValues);
        if ($valueCount === 0) {
            return 0.0;
        }

        // Per label: [count of values equal to the base value, count of the others]
        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }

        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            $rowIndex = $value === $baseValue ? 0 : 1;
            ++$countMatrix[$label][$rowIndex];
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; ++$i) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += ($countMatrix[$label][$i] / (float) $sum) ** 2;
                }
            }

            // An empty partition has $sum === 0 and therefore contributes 0
            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / $valueCount;
    }
149
150
    /**
     * Sets how many columns may be considered when a split is decided at an
     * internal node of the tree.
     *
     * With 0 (the default) all features are used; any other value is the
     * maximum number of columns randomly selected for each split operation.
     *
     * @return $this
     *
     * @throws InvalidArgumentException when $numFeatures is negative
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures >= 0) {
            $this->numUsableFeatures = $numFeatures;

            return $this;
        }

        throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
    }
171
172
    /**
     * Stores the names used to represent the columns, e.g. in the HTML output
     * or as keys of the feature importances.
     *
     * Once a feature count is known (non-zero), the number of names has to
     * match it exactly.
     *
     * @return $this
     *
     * @throws InvalidArgumentException when the number of names differs from a known feature count
     */
    public function setColumnNames(array $names)
    {
        $expected = $this->featureCount;
        $lengthMismatch = $expected !== 0 && count($names) !== $expected;
        if ($lengthMismatch) {
            $message = sprintf('Length of the given array should be equal to feature count %s', $expected);

            throw new InvalidArgumentException($message);
        }

        $this->columnNames = $names;

        return $this;
    }
190
191
    /**
     * Returns the HTML produced by the root node's getHTML() for the current
     * column names.
     */
    public function getHtml(): string
    {
        $columnNames = $this->columnNames;

        return $this->tree->getHTML($columnNames);
    }
195
196
    /**
     * Returns an importance value for each column, keyed by column name.
     *
     * A column's importance is the sum of the impurity decreases of all
     * internal nodes that split on it. When the total is positive the values
     * are normalized so that they add up to 1 and sorted in descending order.
     * The result is cached until the tree is trained again.
     */
    public function getFeatureImportances(): array
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);
        $this->featureImportances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $score = 0;
            foreach ($this->getSplitNodesByColumn($column, $this->tree) as $splitNode) {
                $score += $splitNode->getNodeImpurityDecrease($sampleCount);
            }

            $this->featureImportances[$columnName] = $score;
        }

        // Normalize & sort the importances
        $total = array_sum($this->featureImportances);
        if ($total > 0) {
            foreach ($this->featureImportances as $columnName => $score) {
                $this->featureImportances[$columnName] = $score / $total;
            }

            arsort($this->featureImportances);
        }

        return $this->featureImportances;
    }
231
232
    /**
     * Builds the node for the given records and, unless the node turns out to
     * be terminal, recursively builds its left and right children.
     *
     * @param array $records indexes into the stored samples/targets
     * @param int   $depth   level of the node being built (0 for the root)
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $leaf = $this->getBestSplit($records);
        $leaf->level = $depth;
        $this->actualDepth = max($this->actualDepth, $depth);

        // One pass over the records does three things: detects whether all
        // rows are equal, distributes the rows over both sides of the split
        // and counts the labels so the node can be classified if it ends
        // up being terminal
        $left = [];
        $right = [];
        $labelCounts = [];
        $identical = true;
        $previous = null;

        foreach ($records as $recordNo) {
            $row = $this->samples[$recordNo];

            // Rows are compared loosely (!=) with their predecessor
            if ($previous !== null && $previous != $row) {
                $identical = false;
            }

            $previous = $row;

            // The split criterion decides the side the record goes to
            if ($leaf->evaluate($row)) {
                $left[] = $recordNo;
            } else {
                $right[] = $recordNo;
            }

            $label = $this->targets[$recordNo];
            $labelCounts[$label] = ($labelCounts[$label] ?? 0) + 1;
        }

        $keepSplitting = !$identical
            && $depth < $this->maxDepth
            && count($labelCounts) !== 1;

        if (!$keepSplitting) {
            // Terminal node: it predicts the most frequent label
            $leaf->isTerminal = true;
            arsort($labelCounts);
            $leaf->classValue = (string) key($labelCounts);

            return $leaf;
        }

        if (count($left) > 0) {
            $leaf->leftLeaf = $this->getSplitLeaf($left, $depth + 1);
        }

        if (count($right) > 0) {
            $leaf->rightLeaf = $this->getSplitLeaf($right, $depth + 1);
        }

        return $leaf;
    }
291
292
    /**
     * Finds the split with the lowest Gini index among the selected features.
     *
     * For every selected column the most frequent (preprocessed) value is used
     * as the base value of a candidate split.
     *
     * @param array $records indexes into the stored samples/targets
     *
     * @throws InvalidArgumentException when no candidate split could be built,
     *                                  e.g. because no feature is selected or
     *                                  no column contains a countable value
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = (array) array_combine(
            $records,
            $this->preprocess(array_intersect_key($this->samples, $recordKeys))
        );
        $bestGiniVal = 1;
        $bestSplit = null;
        $features = $this->getSelectedFeatures();
        foreach ($features as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }

            // The most frequent value of the column becomes the base value
            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            if ($baseValue === null) {
                continue;
            }

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;
                $split->records = $records;

                // If a numeric column is to be selected, then
                // the original numeric value and the selected operator
                // will also be saved into the leaf for future access
                if ($this->columnTypes[$i] === self::CONTINUOUS) {
                    $matches = [];
                    preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                    $split->operator = $matches[1];
                    $split->numericValue = (float) $matches[2];
                }

                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        // Returning null would violate the declared return type; report the
        // situation explicitly instead.
        if ($bestSplit === null) {
            throw new InvalidArgumentException('Unable to find a split for the given records');
        }

        return $bestSplit;
    }
341
342
    /**
     * Returns the column indexes the tree may use when deciding a split.
     *
     * - Features set through setSelectedFeatures() are returned as they are.
     * - Otherwise, when a number was given with setNumFeatures(), a random
     *   selection of at most that many features is returned in ascending order.
     * - Otherwise all features are returned.
     */
    protected function getSelectedFeatures(): array
    {
        if (count($this->selectedFeatures) > 0) {
            return $this->selectedFeatures;
        }

        $candidates = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0) {
            return $candidates;
        }

        // Never ask for more features than the data set has
        $limit = min($this->numUsableFeatures, $this->featureCount);

        shuffle($candidates);
        $picked = array_slice($candidates, 0, $limit);
        sort($picked);

        return $picked;
    }
377
378
    /**
     * Converts the values of continuous columns into discrete ones, using the
     * column median as threshold: values up to the median become "<= median",
     * all others "> median". Other columns are passed through unchanged.
     *
     * @param array $samples rows of feature values
     *
     * @return array the converted rows as a list, every row being a list of column values
     */
    protected function preprocess(array $samples): array
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOUS) {
                $median = Mean::median($values);
                // Mapping instead of a by-reference foreach leaves no
                // dangling reference behind after the loop
                $values = array_map(
                    static function ($value) use ($median): string {
                        return $value <= $median ? "<= {$median}" : "> {$median}";
                    },
                    $values
                );
            }

            $columns[] = $values;
        }

        // array_map(null, $array) with a single array returns that array
        // unchanged instead of zipping it, so a single column has to be
        // wrapped into one-element rows explicitly
        if (count($columns) === 1) {
            return array_map(
                static function ($value): array {
                    return [$value];
                },
                $columns[0]
            );
        }

        // Zipping the columns with a null callback transposes the 2D array
        return array_map(null, ...$columns);
    }
403
404
    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * - A column containing at least one float is never categorical.
     * - Otherwise a column containing any non-numeric value is categorical.
     * - Otherwise the column is categorical when the number of distinct values
     *   is at most one fifth (20%) of the number of values.
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        $total = count($columnValues);

        $numericCount = 0;
        foreach ($columnValues as $value) {
            if (is_float($value)) {
                return false;
            }

            if (is_numeric($value)) {
                ++$numericCount;
            }
        }

        if ($numericCount !== $total) {
            return true;
        }

        $distinctCount = count(array_count_values($columnValues));

        return $distinctCount <= $total / 5;
    }
427
428
    /**
     * Fixes the set of features considered when a column is chosen for a
     * split; getSelectedFeatures() returns a non-empty set unchanged.
     *
     * @param array $selectedFeatures column indexes to use
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $this->selectedFeatures = $selectedFeatures;
    }
435
436
    /**
     * Collects the internal (non-terminal) nodes of the given subtree that use
     * the given column as their split criterion.
     *
     * @return array the matching nodes: the node itself first, then the matches
     *               of the left subtree, then those of the right subtree
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        if ($node->isTerminal) {
            return [];
        }

        $matched = $node->columnIndex === $column ? [$node] : [];

        foreach ([$node->leftLeaf, $node->rightLeaf] as $child) {
            if ($child !== null) {
                $matched = array_merge($matched, $this->getSplitNodesByColumn($column, $child));
            }
        }

        return $matched;
    }
463
464
    /**
     * Walks the tree from the root, following the left child whenever the
     * node's evaluate() accepts the sample and the right child otherwise.
     *
     * @return mixed the class value of the terminal node that is reached, or
     *               the first known label when the walk ends on a missing child
     */
    protected function predictSample(array $sample)
    {
        $current = $this->tree;
        do {
            if ($current->isTerminal) {
                return $current->classValue;
            }

            $current = $current->evaluate($sample)
                ? $current->leftLeaf
                : $current->rightLeaf;
        } while ($current);

        return $this->labels[0];
    }
484
}
485