Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Classification/DecisionTree.php (1 issue)

Check for implicit conversion of array to boolean.

Best Practice Bug Minor

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\Helper\Predictable;
10
use Phpml\Helper\Trainable;
11
use Phpml\Math\Statistic\Mean;
12
13
class DecisionTree implements Classifier
14
{
15
    use Trainable, Predictable;
16
17
    // Column type assigned by getColumnTypes() when isCategoricalColumn() returns false;
    // preprocess() discretizes such columns around their median.
    public const CONTINUOUS = 1;
18
19
    // Column type assigned by getColumnTypes() when isCategoricalColumn() returns true.
    public const NOMINAL = 2;
20
21
    /**
22
     * @var int Deepest level passed to getSplitLeaf() so far (the root is level 0)
23
     */
24
    public $actualDepth = 0;
25
26
    /**
27
     * @var array Type of each column (self::CONTINUOUS or self::NOMINAL), filled in by train()
28
     */
29
    protected $columnTypes = [];
30
31
    /**
32
     * @var DecisionTreeLeaf|null Root node built by train(); null until train() has been called
33
     */
34
    protected $tree = null;
35
36
    /**
37
     * @var int Depth at which getSplitLeaf() stops splitting and marks the node as terminal
38
     */
39
    protected $maxDepth;
40
41
    /**
42
     * @var array Distinct target values seen in the training targets
43
     */
44
    private $labels = [];
45
46
    /**
47
     * @var int Number of values in the first training sample
48
     */
49
    private $featureCount = 0;
50
51
    /**
52
     * @var int Maximum number of randomly chosen columns per split; 0 means all columns
53
     */
54
    private $numUsableFeatures = 0;
55
56
    /**
57
     * @var array Column indices fixed through setSelectedFeatures(); empty when none were set
58
     */
59
    private $selectedFeatures = [];
60
61
    /**
62
     * @var array|null Cached result of getFeatureImportances(); reset to null by train()
63
     */
64
    private $featureImportances;
65
66
    /**
67
     * @var array Name of each column; train() fills in column indices where no name was given
68
     */
69
    private $columnNames = [];
70
71
    /**
     * @param int $maxDepth Depth limit of the tree: a node created at this depth
     *                      (or deeper) is always made terminal
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }
75
76
    /**
     * Trains the tree, or continues training it with additional data.
     *
     * The given samples and targets are appended to the ones supplied earlier
     * and the whole tree is rebuilt from the combined data set.
     *
     * @param array $samples List of samples, each one an array of feature values
     * @param array $targets Target value of every sample, in the same order as $samples
     *
     * @throws InvalidArgumentException if there is no sample to train on
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // The first sample defines the feature count, so at least one sample is
        // required; fail with a clear message instead of an undefined offset error
        if ($this->samples === []) {
            throw new InvalidArgumentException('Cannot train a decision tree without samples');
        }

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Importances depend on the training data, so the cached values are
        // dropped and will be recomputed on the next request
        $this->featureImportances = null;

        // Names that were given (or generated) earlier are kept; the list is
        // only trimmed or padded with column indices to match the feature count
        $nameCount = count($this->columnNames);
        if ($nameCount === 0) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif ($nameCount > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif ($nameCount < $this->featureCount) {
            $this->columnNames = array_merge(
                $this->columnNames,
                range($nameCount, $this->featureCount - 1)
            );
        }
    }
103
104
    /**
     * Determines the type of every column of the given samples.
     *
     * The number of columns is taken from the first sample.
     *
     * @param array $samples List of samples, each one an array of feature values
     *
     * @return int[] self::NOMINAL or self::CONTINUOUS for each column, in column order
     */
    public static function getColumnTypes(array $samples): array
    {
        $columnCount = count($samples[0]);
        $types = [];

        $column = 0;
        while ($column < $columnCount) {
            $types[] = self::isCategoricalColumn(array_column($samples, $column))
                ? self::NOMINAL
                : self::CONTINUOUS;
            ++$column;
        }

        return $types;
    }
116
117
    /**
     * Computes the weighted Gini impurity of splitting a column into the
     * values identical to $baseValue and all the other values.
     *
     * @param mixed $baseValue Value that sends a row to the first partition (compared with ===)
     * @param array $colValues Values of the column, keyed like $targets
     * @param array $targets   Target value of each row
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        // Per label: [rows equal to the base value, all other rows]
        $counts = array_fill_keys($this->labels, [0, 0]);

        foreach ($colValues as $position => $cell) {
            $side = $cell === $baseValue ? 0 : 1;
            ++$counts[$targets[$position]][$side];
        }

        $weighted = 0;
        foreach ([0, 1] as $side) {
            $sideTotal = array_sum(array_column($counts, $side));

            $squares = 0;
            if ($sideTotal > 0) {
                foreach ($this->labels as $label) {
                    $squares += ($counts[$label][$side] / (float) $sideTotal) ** 2;
                }
            }

            // Impurity of the partition, weighted by the partition size
            $weighted += (1 - $squares) * $sideTotal;
        }

        return $weighted / count($colValues);
    }
148
149
    /**
     * Sets how many columns may be considered when a split is decided at an
     * internal node of the tree.
     *
     * With 0 (the default) every feature is used; any other value is the
     * maximum number of columns randomly picked for each split operation.
     *
     * @return $this
     *
     * @throws InvalidArgumentException if $numFeatures is negative
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures >= 0) {
            $this->numUsableFeatures = $numFeatures;

            return $this;
        }

        throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
    }
170
171
    /**
     * Sets the names used to represent the columns, e.g. in the HTML output
     * or as keys of the feature importances.
     *
     * Once the feature count is known (non-zero), the number of names has to
     * match it exactly.
     *
     * @return $this
     *
     * @throws InvalidArgumentException if the number of names differs from the feature count
     */
    public function setColumnNames(array $names)
    {
        $expected = $this->featureCount;
        $mismatch = $expected !== 0 && count($names) !== $expected;

        if ($mismatch) {
            throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
        }

        $this->columnNames = $names;

        return $this;
    }
189
190
    /**
     * Returns the HTML produced by the root node's getHTML() for the current column names.
     */
    public function getHtml(): string
    {
        $root = $this->tree;

        return $root->getHTML($this->columnNames);
    }
194
195
    /**
     * Returns an importance value for each column of the training data,
     * keyed by column name.
     *
     * The importance of a column is the sum of the impurity decreases of all
     * internal nodes splitting on it. When the overall total is positive, the
     * values are normalized to add up to 1 and sorted in descending order.
     * The result is cached until the tree is trained again.
     */
    public function getFeatureImportances(): array
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);

        // Work on a local array: the cache is only filled once everything has
        // been computed, so a failure half-way never leaves a partial result
        $importances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $importance = 0;
            foreach ($this->getSplitNodesByColumn($column, $this->tree) as $node) {
                $importance += $node->getNodeImpurityDecrease($sampleCount);
            }

            $importances[$columnName] = $importance;
        }

        // Normalize & sort the importances
        $total = array_sum($importances);
        if ($total > 0) {
            // Keyed assignment instead of a by-reference loop, which would
            // leave a reference to the last element behind
            foreach ($importances as $columnName => $importance) {
                $importances[$columnName] = $importance / $total;
            }

            arsort($importances);
        }

        $this->featureImportances = $importances;

        return $this->featureImportances;
    }
231
232
    /**
     * Recursively builds the node (and its subtree) for the given records.
     *
     * The node becomes terminal when all records hold the same sample values,
     * when the maximum depth is reached, or when only one target value is
     * left; it is then labelled with the most frequent remaining target.
     *
     * @param array $records Indices of the training samples handled by this node
     * @param int   $depth   Level of the node being created (0 for the root)
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        // Traverse all records to see if all records belong to the same class,
        // otherwise group the records so that we can classify the leaf
        // in case maximum depth is reached
        $leftRecords = [];
        $rightRecords = [];
        $remainingTargets = [];
        $prevRecord = null;
        $allSame = true;

        foreach ($records as $recordNo) {
            $record = $this->samples[$recordNo];

            // Compare against null explicitly: an empty-array sample is falsy and
            // would otherwise be mistaken for "no previous record". The loose
            // array comparison (!=) is kept so equal-valued samples still match.
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }

            $prevRecord = $record;

            // According to the split criterion, this record will
            // belong to either left or the right side in the next split
            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            // Group remaining targets
            $target = $this->targets[$recordNo];
            $remainingTargets[$target] = ($remainingTargets[$target] ?? 0) + 1;
        }

        if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
            $split->isTerminal = 1;

            // The most frequent target becomes the class of the leaf
            arsort($remainingTargets);
            $split->classValue = key($remainingTargets);

            return $split;
        }

        // A side without any record gets no child node
        if ($leftRecords !== []) {
            $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
        }

        if ($rightRecords !== []) {
            $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
        }

        return $split;
    }
291
292
    /**
     * Finds the split with the lowest Gini index for the given records.
     *
     * For every candidate column, the most frequent (preprocessed) value is
     * used as the base value of the split.
     *
     * @param array $records Indices of the training samples to split
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        // Flip once and reuse the lookup for both targets and samples
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = array_intersect_key($this->samples, $recordKeys);
        $samples = array_combine($records, $this->preprocess($samples));

        $bestGiniVal = 1;
        $bestSplit = null;
        foreach ($this->getSelectedFeatures() as $i) {
            // Keys are preserved because getGiniIndex() looks targets up by key
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }

            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit !== null && $bestGiniVal <= $gini) {
                continue;
            }

            $isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;

            $split = new DecisionTreeLeaf();
            $split->value = $baseValue;
            $split->giniIndex = $gini;
            $split->columnIndex = $i;
            $split->isContinuous = $isContinuous;
            $split->records = $records;

            // If a numeric column is to be selected, then
            // the original numeric value and the selected operator
            // will also be saved into the leaf for future access
            if ($isContinuous) {
                $matches = [];
                preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                $split->operator = $matches[1];
                $split->numericValue = (float) $matches[2];
            }

            $bestSplit = $split;
            $bestGiniVal = $gini;
        }

        return $bestSplit;
    }
335
336
    /**
     * Returns the features/columns available to the tree for the decision
     * making process.
     *
     * - If features were selected manually through setSelectedFeatures(),
     *   only these features are returned.
     * - Otherwise, if a number was given through setNumFeatures(), a random
     *   selection of up to that many features is returned in ascending order.
     * - If neither method was called beforehand, all features are returned.
     */
    protected function getSelectedFeatures(): array
    {
        // Arrays are compared with [] explicitly instead of relying on the
        // implicit array-to-boolean conversion
        if ($this->selectedFeatures !== []) {
            return $this->selectedFeatures;
        }

        $allFeatures = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0) {
            return $allFeatures;
        }

        // Never ask for more features than there are columns
        $numFeatures = min($this->numUsableFeatures, $this->featureCount);

        shuffle($allFeatures);
        $selectedFeatures = array_slice($allFeatures, 0, $numFeatures, false);
        sort($selectedFeatures);

        return $selectedFeatures;
    }
371
372
    /**
     * Converts the values of continuous columns into discrete values by
     * using the median of the column as a threshold.
     *
     * Values of a continuous column become "<= median" or "> median";
     * all other columns are passed through unchanged.
     *
     * @param array $samples List of samples, each one an array of feature values
     *
     * @return array The converted samples as a re-indexed list of rows
     */
    protected function preprocess(array $samples): array
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOUS) {
                $median = Mean::median($values);

                // Keyed assignment avoids the dangling reference a by-reference
                // loop would leave behind for the next column
                foreach ($values as $key => $value) {
                    $values[$key] = $value <= $median ? "<= {$median}" : "> {$median}";
                }
            }

            $columns[] = $values;
        }

        // array_map(null, $array) returns a single array unchanged instead of
        // zipping it, so a one-column data set has to be wrapped into rows by hand
        if (count($columns) === 1) {
            return array_map(
                static function ($value): array {
                    return [$value];
                },
                $columns[0]
            );
        }

        // Below method is a strange yet very simple & efficient method
        // to get the transpose of a 2D array
        return array_map(null, ...$columns);
    }
397
398
    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * There are two main indicators that *may* show this:
     * 1- The column contains non-numeric (e.g. string) values and no float values
     * 2- The number of unique values in the column is only a small fraction
     *    of all values in that column (lower than or equal to 20% of them)
     *
     * @param array $columnValues All values of a single column
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        // One pass over the values: a single float makes the column continuous
        $numericCount = 0;
        foreach ($columnValues as $value) {
            if (is_float($value)) {
                return false;
            }

            if (is_numeric($value)) {
                ++$numericCount;
            }
        }

        // At least one non-numeric value: treat the column as categorical
        $count = count($columnValues);
        if ($numericCount !== $count) {
            return true;
        }

        $distinctValues = array_count_values($columnValues);

        return count($distinctValues) <= $count / 5;
    }
421
422
    /**
     * Fixes the set of columns that getSelectedFeatures() returns as split
     * candidates; an empty array removes the restriction.
     *
     * @param array $selectedFeatures Column indices to consider for a split
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $this->selectedFeatures = $selectedFeatures;
    }
429
430
    /**
     * Collects and returns the internal (non-terminal) nodes of the subtree
     * rooted at $node that use the given column as a split criterion.
     *
     * The nodes are listed in pre-order: the node itself, then the matches
     * of its left subtree, then those of its right subtree.
     *
     * @param int              $column Index of the column to look for
     * @param DecisionTreeLeaf $node   Root of the subtree to search
     *
     * @return DecisionTreeLeaf[]
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        // $node can never be null here (non-nullable parameter type), so only
        // terminal nodes end the recursion
        if ($node->isTerminal) {
            return [];
        }

        $nodes = [];
        if ($node->columnIndex === $column) {
            $nodes[] = $node;
        }

        foreach ([$node->leftLeaf, $node->rightLeaf] as $child) {
            if ($child) {
                $nodes = array_merge($nodes, $this->getSplitNodesByColumn($column, $child));
            }
        }

        return $nodes;
    }
459
460
    /**
     * Walks the tree from the root to a terminal node and returns its class.
     *
     * At every internal node the sample goes to the left child when the
     * node's evaluate() accepts it, otherwise to the right child. If the
     * walk ends on a missing child, the first known label is returned.
     *
     * @param array $sample Feature values of the sample to classify
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $current = $this->tree;

        while (true) {
            if ($current->isTerminal) {
                break;
            }

            $current = $current->evaluate($sample) ? $current->leftLeaf : $current->rightLeaf;

            if (!$current) {
                break;
            }
        }

        if ($current) {
            return $current->classValue;
        }

        return $this->labels[0];
    }
480
}
481