Test Setup Failed
Push — master ( 5e02b8...d3888e )
by Arkadiusz
11:41
created

RandomForest::initSingleClassifier()   B

Complexity

Conditions 6
Paths 13

Size

Total Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 28
rs 8.8497
c 0
b 0
f 0
cc 6
nc 13
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Classifier;
8
use Phpml\Classification\DecisionTree;
9
use Phpml\Exception\InvalidArgumentException;
10
11
class RandomForest extends Bagging
{
    /**
     * How many features each tree may consider: either a float ratio of the
     * total feature count, or one of the keywords 'sqrt' / 'log'.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * Column names passed to every tree. When still null at the time a tree is
     * initialised, it is filled with the numeric indices 0..featureCount-1.
     *
     * @var array|null
     */
    protected $columnNames;

    /**
     * Initializes RandomForest with the given number of trees. More trees
     * may increase the prediction performance while it will also substantially
     * increase the processing time and the required memory
     */
    public function __construct(int $numClassifier = 50)
    {
        parent::__construct($numClassifier);

        // NOTE(review): setSubsetRatio() is not defined in this class; it is
        // presumably inherited from Bagging and controls the sample (row)
        // subset ratio rather than the feature ratio - confirm.
        $this->setSubsetRatio(1.0);
    }

    /**
     * Determines how many of the original columns (features) are used to
     * build the feature subset each base classifier is trained on.
     *
     * Allowed values: 'sqrt', 'log' or a float between 0.1 and 1.0. The type
     * check is strict, so an integer such as 1 is rejected; pass 1.0 instead.
     *
     * The default is 'log', which yields log(numFeatures, 2) + 1 features;
     * 'sqrt' yields sqrt(numFeatures) + 1 (both truncated to an integer).
     *
     * @param float|string $ratio
     *
     * @throws InvalidArgumentException when $ratio is neither a float in
     *                                  [0.1, 1.0] nor 'sqrt' / 'log'
     */
    public function setFeatureSubsetRatio($ratio): self
    {
        if (!is_string($ratio) && !is_float($ratio)) {
            throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
        }

        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
        }

        if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
            throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree
     *
     * @return $this
     *
     * @throws InvalidArgumentException when $classifier is not DecisionTree::class
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
        }

        parent::setClassifer($classifier, $classifierOptions);

        return $this;
    }

    /**
     * Returns one importance value per column, aggregated over all trees in
     * the forest.
     *
     * The per-tree importances are summed column by column, then divided by
     * the grand total so the returned values add up to 1, and finally sorted
     * in descending order (keys preserved). When the grand total is zero the
     * raw sums are returned without normalisation, because dividing by zero
     * would fail.
     */
    public function getFeatureImportances(): array
    {
        // Traverse each tree and sum importance of the columns
        $sum = [];
        foreach ($this->classifiers as $tree) {
            /** @var DecisionTree $tree */
            foreach ($tree->getFeatureImportances() as $column => $importance) {
                $sum[$column] = ($sum[$column] ?? 0) + $importance;
            }
        }

        // Normalize only when there is something to normalize by: a zero
        // total (e.g. every tree reports zero importance for every column)
        // would otherwise trigger a division by zero.
        $total = array_sum($sum);
        if ((float) $total !== 0.0) {
            array_walk($sum, static function (&$importance) use ($total): void {
                $importance /= $total;
            });
        }

        arsort($sum);

        return $sum;
    }

    /**
     * A string array to represent the columns is given. They are useful
     * when trying to print some information about the trees such as feature importances
     *
     * @return $this
     */
    public function setColumnNames(array $names)
    {
        $this->columnNames = $names;

        return $this;
    }

    /**
     * Configures a freshly created tree with the forest's column names and
     * the number of features it is allowed to use.
     *
     * @return DecisionTree
     *
     * @throws InvalidArgumentException when $classifier is not a DecisionTree
     */
    protected function initSingleClassifier(Classifier $classifier): Classifier
    {
        if (!$classifier instanceof DecisionTree) {
            throw new InvalidArgumentException(
                sprintf('Classifier %s expected, got %s', DecisionTree::class, get_class($classifier))
            );
        }

        // NOTE(review): $this->featureCount is not declared in this class;
        // presumably it is set by the parent before this method runs - confirm.
        if (is_float($this->featureSubsetRatio)) {
            // NOTE(review): truncation can yield 0 for small feature counts
            // (e.g. ratio 0.1 with fewer than 10 features); verify how
            // DecisionTree::setNumFeatures() interprets 0.
            $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            $featureCount = (int) ($this->featureCount ** .5) + 1;
        } else {
            $featureCount = (int) log($this->featureCount, 2) + 1;
        }

        // The "+ 1" above can overshoot; never ask for more features than exist.
        if ($featureCount >= $this->featureCount) {
            $featureCount = $this->featureCount;
        }

        // Fall back to numeric column indices when no names were provided.
        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);
        }

        return $classifier
            ->setColumnNames($this->columnNames)
            ->setNumFeatures($featureCount);
    }
}
158