Completed
Push — master ( 240a22...a33d5f )
by Arkadiusz
02:53
created

RandomForest::getFeatureImportances()   B

Complexity

Conditions 5
Paths 8

Size

Total Lines 27
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 27
rs 8.439
c 0
b 0
f 0
cc 5
eloc 14
nc 8
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Ensemble\Bagging;
8
use Phpml\Classification\DecisionTree;
9
use Phpml\Classification\NaiveBayes;
10
use Phpml\Classification\Classifier;
11
12
/**
 * Bagging ensemble restricted to DecisionTree base classifiers, where each
 * tree is configured with a limited number of features via
 * {@see RandomForest::setFeatureSubsetRatio()}.
 */
class RandomForest extends Bagging
{
    /**
     * Strategy used to decide how many features each tree receives:
     * 'sqrt', 'log', or a float ratio of the total feature count.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * Column labels forwarded to every tree; generated lazily as 0..n-1 in
     * initSingleClassifier() when the caller never supplied any.
     *
     * @var array|null
     */
    protected $columnNames = null;

    /**
     * Initializes RandomForest with the given number of trees. More trees
     * may increase the prediction performance while it will also substantially
     * increase the processing time and the required memory.
     *
     * The sample subset ratio of the parent ensemble is set to 1.0.
     *
     * @param int $numClassifier number of trees in the forest
     */
    public function __construct($numClassifier = 50)
    {
        parent::__construct($numClassifier);

        $this->setSubsetRatio(1.0);
    }

    /**
     * This method is used to determine how many of the original columns (features)
     * will be used to construct subsets to train base classifiers.
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0.
     *
     * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
     * features to be taken into consideration while selecting subspace of features.
     *
     * Values that are neither float nor string are stored without validation;
     * initSingleClassifier() then treats them like 'log'.
     *
     * @param mixed $ratio string or float should be given
     *
     * @return $this
     *
     * @throws \Exception when a float is out of range or a string is not 'sqrt'/'log'
     */
    public function setFeatureSubsetRatio($ratio)
    {
        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new \Exception("When a float given, feature subset ratio should be between 0.1 and 1.0");
        }

        if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
            throw new \Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' ");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree.
     *
     * @param string $classifier        fully-qualified class name of the base classifier
     * @param array  $classifierOptions options forwarded unchanged to the parent implementation
     *
     * @return $this
     *
     * @throws \Exception when $classifier is not DecisionTree::class
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new \Exception("RandomForest can only use DecisionTree as base classifier");
        }

        return parent::setClassifer($classifier, $classifierOptions);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset. The per-tree importances are summed
     * per column and then divided by the grand total, so the returned values
     * add up to 1. The result is sorted in descending order with keys preserved.
     *
     * When the grand total is zero (no trees, or every tree reports zero
     * importance) the summed values are returned without normalization
     * instead of dividing by zero.
     *
     * @return array column => normalized importance
     */
    public function getFeatureImportances()
    {
        // Traverse each tree and sum importance of the columns
        $sum = [];
        foreach ($this->classifiers as $tree) {
            /* @var $tree DecisionTree */
            foreach ($tree->getFeatureImportances() as $column => $importance) {
                $sum[$column] = ($sum[$column] ?? 0) + $importance;
            }
        }

        // Normalize; skipped when there is nothing to divide by.
        $total = array_sum($sum);
        if ((float) $total !== 0.0) {
            // Key-based assignment avoids leaving a dangling by-reference variable.
            foreach ($sum as $column => $importance) {
                $sum[$column] = $importance / $total;
            }
        }

        arsort($sum);

        return $sum;
    }

    /**
     * A string array to represent the columns is given. They are useful
     * when trying to print some information about the trees such as feature importances.
     *
     * @param array $names
     *
     * @return $this
     */
    public function setColumnNames(array $names)
    {
        $this->columnNames = $names;

        return $this;
    }

    /**
     * Configures a single tree with the column names and the number of
     * features derived from the configured feature subset ratio.
     *
     * @param DecisionTree $classifier
     * @param int          $index      position of the tree in the ensemble (unused here)
     *
     * @return DecisionTree
     */
    protected function initSingleClassifier($classifier, $index)
    {
        if (is_float($this->featureSubsetRatio)) {
            // NOTE(review): a small ratio on few features truncates to 0 here;
            // confirm how DecisionTree::setNumFeatures() treats 0.
            $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            // The strategy lives in featureSubsetRatio, not in featureCount.
            $featureCount = (int) sqrt($this->featureCount) + 1;
        } else {
            $featureCount = (int) log($this->featureCount, 2) + 1;
        }

        // Never ask a tree for more features than the dataset has.
        if ($featureCount >= $this->featureCount) {
            $featureCount = $this->featureCount;
        }

        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);
        }

        return $classifier
            ->setColumnNames($this->columnNames)
            ->setNumFeatures($featureCount);
    }
}
156