Completed
Pull Request — master (#36)
by
unknown
02:40
created

RandomForest::setClassifer()   A

Complexity

Conditions 2
Paths 2

Size

Total Lines 8
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 8
rs 9.4285
c 0
b 0
f 0
cc 2
eloc 4
nc 2
nop 2
1
<?php
2
declare(strict_types=1);
3
4
namespace Phpml\Classification\Ensemble;
5
6
use Phpml\Classification\Ensemble\Bagging;
7
use Phpml\Classification\DecisionTree;
8
use Phpml\Classification\NaiveBayes;
9
use Phpml\Classification\Classifier;
10
11
/**
 * Bagging ensemble restricted to DecisionTree base classifiers, where each
 * tree is configured with a limited number of features.
 *
 * The number of features handed to each tree is derived from
 * {@see RandomForest::$featureSubsetRatio} in initSingleClassifier().
 */
class RandomForest extends Bagging
{
    /**
     * @var float
     */
    protected $subsetRatio = 1.0;

    /**
     * Strategy used to derive the per-tree feature count: either a float
     * ratio of the total feature count, or one of the strings 'sqrt' / 'log'.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * This method is used to determine how much of the original columns (features)
     * will be used to construct subsets to train base classifiers.<br>
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
     *
     * If there are many features that diminish classification performance, then
     * small values should be preferred. When this method is never called, the
     * default strategy is 'log'.
     *
     * Only float and string arguments are validated; a value of any other type
     * is stored as given and is later treated like 'log' by
     * initSingleClassifier().
     *
     * @param mixed $ratio string or float should be given
     * @return $this
     * @throws \Exception when a float is outside [0.1, 1.0] or a string is
     *                    neither 'sqrt' nor 'log'
     */
    public function setFeatureSubsetRatio($ratio)
    {
        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new \Exception("When a float given, feature subset ratio should be between 0.1 and 1.0");
        }

        // Strict membership test: only the exact strings 'sqrt' and 'log' are accepted.
        if (is_string($ratio) && !in_array($ratio, ['sqrt', 'log'], true)) {
            throw new \Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' ");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree
     *
     * @param string $classifier      class name of the base classifier; must be DecisionTree::class
     * @param array $classifierOptions options forwarded unchanged to the parent implementation
     * @return $this
     * @throws \Exception when $classifier is not DecisionTree::class
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new \Exception("RandomForest can only use DecisionTree as base classifier");
        }

        return parent::setClassifer($classifier, $classifierOptions);
    }

    /**
     * Configures a single base classifier with the number of features it may use.
     *
     * The count is derived from $featureSubsetRatio:
     *  - float:  (int)(ratio * total feature count)
     *  - 'sqrt': (int)sqrt(total feature count) + 1
     *  - otherwise ('log'): (int)log2(total feature count) + 1
     * and is capped at the total feature count.
     *
     * @param DecisionTree $classifier
     * @param int $index not used by this implementation
     * @return DecisionTree whatever setNumFeatures() returns
     */
    protected function initSingleClassifier($classifier, $index)
    {
        if (is_float($this->featureSubsetRatio)) {
            $featureCount = (int)($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            // The strategy lives in $featureSubsetRatio; comparing the numeric
            // $featureCount against 'sqrt' made this branch effectively dead.
            $featureCount = (int)sqrt($this->featureCount) + 1;
        } else {
            $featureCount = (int)log($this->featureCount, 2) + 1;
        }

        // Never ask a tree for more features than actually exist.
        if ($featureCount >= $this->featureCount) {
            $featureCount = $this->featureCount;
        }

        return $classifier->setNumFeatures($featureCount);
    }
}
87