Completed
Push — master ( 72b25f...1d7350 )
by Arkadiusz
04:38
created

RandomForest::__construct()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 6
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 6
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 3
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Ensemble\Bagging;
8
use Phpml\Classification\DecisionTree;
9
use Phpml\Classification\NaiveBayes;
10
use Phpml\Classification\Classifier;
11
12
class RandomForest extends Bagging
{
    /**
     * Strategy used to decide how many features each base classifier is
     * trained on: 'sqrt', 'log', or a float ratio between 0.1 and 1.0.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * Forwards the classifier count to the Bagging constructor and then calls
     * setSubsetRatio(1.0) on the inherited API.
     *
     * @param int $numClassifier number of base classifiers (passed to the parent as-is)
     */
    public function __construct($numClassifier = 50)
    {
        parent::__construct($numClassifier);

        $this->setSubsetRatio(1.0);
    }

    /**
     * This method is used to determine how much of the original columns (features)
     * will be used to construct subsets to train base classifiers.<br>
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
     *
     * If there are many features that diminish classification performance, then
     * small values should be preferred. The default value is 'log'.
     *
     * @param mixed $ratio string or float should be given
     * @return $this
     * @throws \Exception when $ratio is neither a float in [0.1, 1.0] nor 'sqrt'/'log'
     */
    public function setFeatureSubsetRatio($ratio)
    {
        // Any other type (int, null, array, ...) would otherwise be stored and
        // silently treated like 'log' when the classifiers are initialised.
        if (!is_float($ratio) && !is_string($ratio)) {
            throw new \Exception("Feature subset ratio should be a float or a string");
        }

        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new \Exception("When a float given, feature subset ratio should be between 0.1 and 1.0");
        }

        if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
            throw new \Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' ");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree
     *
     * @param string $classifier
     * @param array $classifierOptions
     * @return $this
     * @throws \Exception when $classifier is not the DecisionTree class name
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new \Exception("RandomForest can only use DecisionTree as base classifier");
        }

        return parent::setClassifer($classifier, $classifierOptions);
    }

    /**
     * Sets the number of features the given tree may use, derived from the
     * configured feature subset ratio and $this->featureCount (inherited;
     * presumably the number of columns in the training data — confirm in Bagging).
     *
     * @param DecisionTree $classifier
     * @param int $index
     * @return DecisionTree
     */
    protected function initSingleClassifier($classifier, $index)
    {
        if (is_float($this->featureSubsetRatio)) {
            $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            // The strategy lives in featureSubsetRatio; comparing featureCount
            // against 'sqrt' made this branch unreachable for real data.
            $featureCount = (int) sqrt($this->featureCount) + 1;
        } else {
            $featureCount = (int) log($this->featureCount, 2) + 1;
        }

        // Never ask for more features than are available.
        if ($featureCount >= $this->featureCount) {
            $featureCount = $this->featureCount;
        }

        return $classifier->setNumFeatures($featureCount);
    }
}
90