Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

RandomForest::getFeatureImportances()   B

Complexity

Conditions 5
Paths 8

Size

Total Lines 27
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 27
rs 8.439
c 0
b 0
f 0
cc 5
eloc 14
nc 8
nop 0
<?php

declare(strict_types=1);

namespace Phpml\Classification\Ensemble;

use Exception;
use Phpml\Classification\Classifier;
use Phpml\Classification\DecisionTree;

class RandomForest extends Bagging
{
    /**
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * @var array|null
     */
    protected $columnNames = null;

    /**
     * Initializes RandomForest with the given number of trees. More trees
     * may improve prediction performance, but they also substantially
     * increase processing time and memory use.
     */
    public function __construct(int $numClassifier = 50)
    {
        parent::__construct($numClassifier);

        $this->setSubsetRatio(1.0);
    }

    /**
     * This method is used to determine how many of the original columns (features)
     * will be used to construct subsets to train base classifiers.<br>
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
     *
     * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
     * features to be taken into consideration while selecting subspace of features
     *
     * @param mixed $ratio string or float should be given
     *
     * @return $this
     *
     * @throws \Exception
     */
    public function setFeatureSubsetRatio($ratio)
    {
        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new Exception('When a float given, feature subset ratio should be between 0.1 and 1.0');
        }

        if (is_string($ratio) && $ratio != 'sqrt' && $ratio != 'log') {
            throw new Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' ");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree
     *
     * @return $this
     *
     * @throws \Exception
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier != DecisionTree::class) {
            throw new Exception('RandomForest can only use DecisionTree as base classifier');
        }

        return parent::setClassifer($classifier, $classifierOptions);
    }

    /**
     * Returns an array with an importance value for each column in the
     * given dataset. The importance value for a column is the average
     * importance of that column across all trees in the forest,
     * normalized so that all values sum to 1.
     */
    public function getFeatureImportances(): array
    {
        // Traverse each tree and sum importance of the columns
        $sum = [];
        foreach ($this->classifiers as $tree) {
            /** @var DecisionTree $tree */
            $importances = $tree->getFeatureImportances();

            foreach ($importances as $column => $importance) {
                if (array_key_exists($column, $sum)) {
                    $sum[$column] += $importance;
                } else {
                    $sum[$column] = $importance;
                }
            }
        }

        // Normalize & sort the importance values
        $total = array_sum($sum);
        foreach ($sum as &$importance) {
            $importance /= $total;
        }
        // Break the reference so later writes to $importance cannot alter $sum
        unset($importance);

        arsort($sum);

        return $sum;
    }

    /**
     * Sets the column names, given as an array of strings. The names are useful
     * when printing information about the trees, such as feature importances.
     *
     * @return $this
     */
    public function setColumnNames(array $names)
    {
        $this->columnNames = $names;

        return $this;
    }

    /**
     * @param DecisionTree $classifier
     *
     * @return DecisionTree
     */
    protected function initSingleClassifier(Classifier $classifier): Classifier
    {
        if (is_float($this->featureSubsetRatio)) {
            $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            $featureCount = (int) sqrt($this->featureCount) + 1;
        } else {
            $featureCount = (int) log($this->featureCount, 2) + 1;
        }

        if ($featureCount >= $this->featureCount) {
            $featureCount = $this->featureCount;
        }

        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);
        }

        return $classifier
It seems like you code against a concrete implementation and not the interface Phpml\Classification\Classifier as the method setColumnNames() does only exist in the following implementations of said interface: Phpml\Classification\DecisionTree, Phpml\Classification\Ensemble\RandomForest.

Let’s take a look at an example:

interface User
{
    /** @return string */
    public function getPassword();
}

class MyUser implements User
{
    public function getPassword()
    {
        // return something
    }

    public function getDisplayName()
    {
        // return some name.
    }
}

class AuthSystem
{
    public function authenticate(User $user)
    {
        $this->logger->info(sprintf('Authenticating %s.', $user->getDisplayName()));
        // do something.
    }
}

In the above example, the authenticate() method works fine as long as you just pass instances of MyUser. However, if you now also want to pass a different implementation of User which does not have a getDisplayName() method, the code will break.

Available Fixes

  1. Change the type-hint for the parameter:

    class AuthSystem
    {
        public function authenticate(MyUser $user) { /* ... */ }
    }
    
  2. Add an additional type-check:

    class AuthSystem
    {
        public function authenticate(User $user)
        {
            if ($user instanceof MyUser) {
                $this->logger->info(/** ... */);
            }
    
            // or alternatively
            if ( ! $user instanceof MyUser) {
                throw new \LogicException(
                    '$user must be an instance of MyUser, '
                   .'other instances are not supported.'
                );
            }
    
        }
    }
    
Note: PHP Analyzer uses reverse abstract interpretation to narrow down the types inside the if block in such a case.

  3. Add the method to the interface:

    interface User
    {
        /** @return string */
        public function getPassword();
    
        /** @return string */
        public function getDisplayName();
    }
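
Applied to the flagged call, fix 2 could look roughly like the sketch below. The `Classifier`, `DecisionTree` and `OtherClassifier` types here are reduced, hypothetical stand-ins for the php-ml classes (they only carry what the example needs), and `configureTree()` is an illustrative free function, not a method of the library.

```php
<?php

declare(strict_types=1);

// Hypothetical stand-in for Phpml\Classification\Classifier.
interface Classifier
{
}

// Hypothetical stand-in for Phpml\Classification\DecisionTree, reduced to
// the two fluent setters that initSingleClassifier() chains.
final class DecisionTree implements Classifier
{
    /** @var array */
    public $columnNames = [];

    /** @var int */
    public $numFeatures = 0;

    public function setColumnNames(array $names): self
    {
        $this->columnNames = $names;

        return $this;
    }

    public function setNumFeatures(int $numFeatures): self
    {
        $this->numFeatures = $numFeatures;

        return $this;
    }
}

// Another implementation of the interface that lacks both setters.
final class OtherClassifier implements Classifier
{
}

// Illustrative helper: checks the concrete type before calling methods
// that the Classifier interface does not declare.
function configureTree(Classifier $classifier, array $columnNames, int $featureCount): Classifier
{
    if (!$classifier instanceof DecisionTree) {
        throw new LogicException(
            '$classifier must be an instance of DecisionTree, '
            . 'other instances are not supported.'
        );
    }

    return $classifier
        ->setColumnNames($columnNames)
        ->setNumFeatures($featureCount);
}
```

With the type check in place, passing an `OtherClassifier` fails with a clear `LogicException` instead of a fatal "undefined method" error, and the analyzer can narrow the type after the guard.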
    
            ->setColumnNames($this->columnNames)
            ->setNumFeatures($featureCount);
    }
}
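
To make the feature-subset rule in `initSingleClassifier()` easier to check by hand, here is a minimal standalone sketch of the same arithmetic. The function name `featureSubsetSize()` and its parameters are hypothetical and not part of php-ml; it assumes the ratio has already been validated by something like `setFeatureSubsetRatio()`.

```php
<?php

declare(strict_types=1);

/**
 * Hypothetical helper mirroring the subset-size rule of the class above.
 *
 * @param float|string $ratio         a float in [0.1, 1.0], 'sqrt' or 'log'
 * @param int          $totalFeatures number of columns in the dataset
 */
function featureSubsetSize($ratio, int $totalFeatures): int
{
    if (is_float($ratio)) {
        // Fraction of all features, truncated toward zero
        $size = (int) ($ratio * $totalFeatures);
    } elseif ($ratio === 'sqrt') {
        $size = (int) sqrt($totalFeatures) + 1;
    } else {
        // 'log': base-2 logarithm of the feature count, plus one
        $size = (int) log($totalFeatures, 2) + 1;
    }

    // Never ask for more features than the dataset has
    return min($size, $totalFeatures);
}

echo featureSubsetSize('log', 16), PHP_EOL;  // 5
echo featureSubsetSize('sqrt', 16), PHP_EOL; // 5
echo featureSubsetSize(0.5, 16), PHP_EOL;    // 8
```

Note that the `(int)` cast binds more tightly than `+`, so the `+ 1` is added after truncation in both the 'sqrt' and 'log' branches.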