Bagging::__construct()   A
last analyzed

Complexity

Conditions 1
Paths 1

Size

Total Lines 3
Code Lines 1

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Ensemble;
6
7
use Phpml\Classification\Classifier;
8
use Phpml\Classification\DecisionTree;
9
use Phpml\Exception\InvalidArgumentException;
10
use Phpml\Helper\Predictable;
11
use Phpml\Helper\Trainable;
12
use ReflectionClass;
13
14
/**
 * Bootstrap-aggregating ("bagging") ensemble classifier.
 *
 * Trains a number of base classifiers, each on a random subset of the
 * training data drawn with replacement. Predicts a sample by majority vote
 * over the base classifiers' predictions.
 */
class Bagging implements Classifier
{
    use Trainable;
    use Predictable;

    /**
     * Number of training samples accumulated over all train() calls.
     *
     * @var int
     */
    protected $numSamples;

    /**
     * Number of features, taken from the first sample of the latest train() call.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Number of base classifiers in the ensemble.
     *
     * @var int
     */
    protected $numClassifier;

    /**
     * Fully qualified class name of the base classifier.
     *
     * @var string
     */
    protected $classifier = DecisionTree::class;

    /**
     * Positional constructor arguments for the base classifier (keys are ignored).
     *
     * @var array
     */
    protected $classifierOptions = ['depth' => 20];

    /**
     * Trained base classifiers.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Bootstrap subset size as a fraction of the number of training samples.
     *
     * @var float
     */
    protected $subsetRatio = 0.7;

    /**
     * Creates an ensemble classifier with the given number of base classifiers.
     * The default number of base classifiers is 50.
     * More base classifiers give better performance, at the cost of processing time.
     */
    public function __construct(int $numClassifier = 50)
    {
        $this->numClassifier = $numClassifier;
    }

    /**
     * Sets the ratio of samples used to build each 'bootstrap' subset, i.e. the
     * random samples drawn from the original dataset with replacement (repeats
     * allowed) that train each base classifier.
     *
     * @param float $ratio value in the closed range [0.1, 1.0]
     *
     * @return $this
     *
     * @throws InvalidArgumentException when $ratio is outside [0.1, 1.0]
     */
    public function setSubsetRatio(float $ratio)
    {
        if ($ratio < 0.1 || $ratio > 1.0) {
            throw new InvalidArgumentException('Subset ratio should be between 0.1 and 1.0');
        }

        $this->subsetRatio = $ratio;

        return $this;
    }

    /**
     * Sets the base classifier. The default value is DecisionTree::class, but
     * any class that implements the <i>Classifier</i> interface can be used. <br>
     * Give the classifier's parameters in the order they appear in its
     * constructor; array keys (parameter names) are ignored.
     *
     * Note: the method name is misspelled but kept for backward compatibility;
     * setClassifier() is the correctly spelled equivalent.
     *
     * @return $this
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        $this->classifier = $classifier;
        $this->classifierOptions = $classifierOptions;

        return $this;
    }

    /**
     * Correctly spelled alias of setClassifer().
     *
     * It delegates to setClassifer(), so any validation a subclass adds by
     * overriding setClassifer() still applies.
     *
     * @return $this
     */
    public function setClassifier(string $classifier, array $classifierOptions = [])
    {
        return $this->setClassifer($classifier, $classifierOptions);
    }

    /**
     * Appends the given data to the accumulated training set and retrains
     * every base classifier on its own bootstrap subset of that set.
     *
     * @throws InvalidArgumentException when $samples is empty
     */
    public function train(array $samples, array $targets): void
    {
        if ($samples === []) {
            // Without this guard, $samples[0] below would be undefined.
            throw new InvalidArgumentException('Samples array should not be empty');
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);
        $this->featureCount = count($samples[0]);
        $this->numSamples = count($this->samples);

        // Init classifiers and train them with bootstrap samples.
        $this->classifiers = $this->initClassifiers();
        $index = 0;
        foreach ($this->classifiers as $classifier) {
            [$subsetSamples, $subsetTargets] = $this->getRandomSubset($index);
            $classifier->train($subsetSamples, $subsetTargets);
            ++$index;
        }
    }

    /**
     * Draws a bootstrap subset (with replacement) of the accumulated training data.
     *
     * The global Mersenne Twister generator is reseeded with $index, so each base
     * classifier gets a different but reproducible subset. Note that this reseeds
     * the process-wide mt_rand() state.
     *
     * @return array two-element list: [samples, targets]
     */
    protected function getRandomSubset(int $index): array
    {
        // random_int() ignores mt_srand()/srand() seeds, so mt_rand() is used
        // to make the seed effective. This sampling is not security-sensitive.
        mt_srand($index);

        $samples = [];
        $targets = [];
        $bootstrapSize = $this->subsetRatio * $this->numSamples;
        for ($i = 0; $i < $bootstrapSize; ++$i) {
            $rand = mt_rand(0, $this->numSamples - 1);
            $samples[] = $this->samples[$rand];
            $targets[] = $this->targets[$rand];
        }

        return [$samples, $targets];
    }

    /**
     * Instantiates $numClassifier base classifiers of the configured class and
     * passes each through initSingleClassifier().
     */
    protected function initClassifiers(): array
    {
        $reflection = new ReflectionClass($this->classifier);
        // Options are positional. Since PHP 8.0, newInstanceArgs() treats string
        // keys as named arguments (e.g. the default ['depth' => 20] would fail),
        // so the keys are dropped.
        $arguments = array_values($this->classifierOptions);

        $classifiers = [];
        for ($i = 0; $i < $this->numClassifier; ++$i) {
            /** @var Classifier $obj */
            $obj = $arguments === [] ? $reflection->newInstance() : $reflection->newInstanceArgs($arguments);

            $classifiers[] = $this->initSingleClassifier($obj);
        }

        return $classifiers;
    }

    /**
     * Hook for subclasses to configure each freshly created base classifier.
     * The default implementation returns it unchanged.
     */
    protected function initSingleClassifier(Classifier $classifier): Classifier
    {
        return $classifier;
    }

    /**
     * Predicts one sample by majority vote over all base classifiers.
     *
     * Ties are resolved by the ordering arsort() leaves among equal counts.
     * Predictions must be ints or strings, since array_count_values() only
     * counts those.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $predictions = [];
        foreach ($this->classifiers as $classifier) {
            /** @var Classifier $classifier */
            $predictions[] = $classifier->predict($sample);
        }

        $counts = array_count_values($predictions);
        arsort($counts);
        reset($counts);

        return key($counts);
    }
}
171