DecisionTreeRegressor::terminate()   A
last analyzed

Complexity

Conditions 1
Paths 1

Size

Total Lines 3
Code Lines 1

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
eloc 1
c 1
b 0
f 0
dl 0
loc 3
rs 10
cc 1
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Regression;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Exception\InvalidOperationException;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\Variance;
11
use Phpml\Tree\CART;
12
use Phpml\Tree\Node\AverageNode;
13
use Phpml\Tree\Node\BinaryNode;
14
use Phpml\Tree\Node\DecisionNode;
15
16
/**
 * Regression tree built on top of the CART base class.
 *
 * Leaves are AverageNode instances holding the arithmetic mean of the targets
 * that reached them; splits are chosen by minimising the size-weighted
 * population variance of the two resulting groups.
 *
 * NOTE(review): grow(), search() and bare() are not defined in this class;
 * they are presumably inherited from CART - confirm against the parent.
 */
final class DecisionTreeRegressor extends CART implements Regression
{
    /**
     * Maximum number of feature columns inspected per split. Null means the
     * limit is derived from the training data on each call to train().
     *
     * @var int|null
     */
    protected $maxFeatures;

    /**
     * Impurity threshold at or below which the split search stops early.
     *
     * @var float
     */
    protected $tolerance;

    /**
     * Feature column indices of the data set currently being trained on.
     * Only populated while train() is running.
     *
     * @var array
     */
    protected $columns = [];

    /**
     * @param int      $maxDepth          forwarded unchanged to the CART constructor
     * @param int      $maxLeafSize       forwarded unchanged to the CART constructor
     * @param float    $minPurityIncrease forwarded unchanged to the CART constructor
     * @param int|null $maxFeatures       columns to inspect per split, or null to derive it from the data
     * @param float    $tolerance         impurity at or below which a split is accepted immediately
     *
     * @throws InvalidArgumentException if $maxFeatures is less than 1 or $tolerance is negative
     */
    public function __construct(
        int $maxDepth = PHP_INT_MAX,
        int $maxLeafSize = 3,
        float $minPurityIncrease = 0.,
        ?int $maxFeatures = null,
        float $tolerance = 1e-4
    ) {
        if ($maxFeatures !== null && $maxFeatures < 1) {
            throw new InvalidArgumentException('Max features must be greater than 0');
        }

        if ($tolerance < 0.) {
            throw new InvalidArgumentException('Tolerance must be equal or greater than 0');
        }

        $this->maxFeatures = $maxFeatures;
        $this->tolerance = $tolerance;

        parent::__construct($maxDepth, $maxLeafSize, $minPurityIncrease);
    }

    /**
     * Grows the tree from the given samples and targets.
     *
     * The feature count is taken from the first sample. When no explicit
     * maxFeatures was configured, the rounded square root of the feature
     * count is used for the duration of this call only, so training again on
     * data with a different number of features derives a fresh limit.
     *
     * @param array $samples rows of feature values
     * @param array $targets target value per sample, keyed like $samples
     *
     * @throws InvalidArgumentException if there is no first sample or it has no features
     */
    public function train(array $samples, array $targets): void
    {
        // $samples is a by-value copy, so moving its internal pointer is harmless.
        $firstSample = reset($samples);

        if (!is_array($firstSample) || $firstSample === []) {
            throw new InvalidArgumentException('Training samples must be a non-empty array of non-empty feature arrays');
        }

        $features = count($firstSample);
        $configuredMaxFeatures = $this->maxFeatures;

        $this->columns = range(0, $features - 1);
        $this->maxFeatures = $configuredMaxFeatures ?? (int) round(sqrt($features));

        try {
            $this->grow($samples, $targets);
        } finally {
            // Drop per-training state even when growing fails, and put the
            // configured (possibly null, i.e. "derive") limit back.
            $this->columns = [];
            $this->maxFeatures = $configuredMaxFeatures;
        }
    }

    /**
     * Predicts an outcome for every sample.
     *
     * @param array $samples rows of feature values
     *
     * @return array one entry per sample, in iteration order: the outcome of the
     *               AverageNode found by search(), or null when search() yields
     *               anything else
     *
     * @throws InvalidOperationException if bare() reports that no tree is available
     */
    public function predict(array $samples)
    {
        if ($this->bare()) {
            throw new InvalidOperationException('Regressor must be trained first');
        }

        $predictions = [];

        foreach ($samples as $sample) {
            $node = $this->search($sample);

            $predictions[] = $node instanceof AverageNode
                ? $node->outcome()
                : null;
        }

        return $predictions;
    }

    /**
     * Searches for the split with the lowest weighted variance.
     *
     * The columns are shuffled and only the first maxFeatures of them are
     * inspected; every distinct value of an inspected column is tried as a
     * threshold. The search ends immediately once a candidate's impurity is
     * at or below the tolerance.
     */
    protected function split(array $samples, array $targets): DecisionNode
    {
        $bestVariance = INF;
        $bestColumn = $bestValue = null;
        $bestGroups = [];

        shuffle($this->columns);

        foreach (array_slice($this->columns, 0, $this->maxFeatures) as $column) {
            $values = array_unique(array_column($samples, $column));

            foreach ($values as $value) {
                $groups = $this->partition($column, $value, $samples, $targets);

                $variance = $this->splitImpurity($groups);

                if ($variance < $bestVariance) {
                    $bestColumn = $column;
                    $bestValue = $value;
                    $bestGroups = $groups;
                    $bestVariance = $variance;
                }

                // Good enough: stop scanning both the values and the columns.
                if ($variance <= $this->tolerance) {
                    break 2;
                }
            }
        }

        return new DecisionNode($bestColumn, $bestValue, $bestGroups, $bestVariance);
    }

    /**
     * Builds a leaf from the targets' arithmetic mean, population variance and count.
     */
    protected function terminate(array $targets): BinaryNode
    {
        return new AverageNode(Mean::arithmetic($targets), Variance::population($targets), count($targets));
    }

    /**
     * Computes the impurity of a candidate split as the sum of each group's
     * population variance weighted by the group's share of all samples.
     *
     * @param array $groups groups shaped as [samples, targets], as returned by partition()
     */
    protected function splitImpurity(array $groups): float
    {
        $samplesCount = (int) array_sum(array_map(static function (array $group): int {
            return count($group[0]);
        }, $groups));

        $impurity = 0.;

        foreach ($groups as $group) {
            $k = count($group[1]);

            // Groups with fewer than two targets contribute nothing to the impurity.
            if ($k < 2) {
                continue;
            }

            $variance = Variance::population($group[1]);

            $impurity += ($k / $samplesCount) * $variance;
        }

        return $impurity;
    }

    /**
     * Divides the samples (and their targets) into a left group, where the
     * sample's value in $column is strictly less than $value, and a right
     * group holding everything else.
     *
     * @param int|float $value
     *
     * @return array [[leftSamples, leftTargets], [rightSamples, rightTargets]]
     */
    private function partition(int $column, $value, array $samples, array $targets): array
    {
        $leftSamples = $leftTargets = $rightSamples = $rightTargets = [];
        foreach ($samples as $index => $sample) {
            if ($sample[$column] < $value) {
                $leftSamples[] = $sample;
                $leftTargets[] = $targets[$index];
            } else {
                $rightSamples[] = $sample;
                $rightTargets[] = $targets[$index];
            }
        }

        return [
            [$leftSamples, $leftTargets],
            [$rightSamples, $rightTargets],
        ];
    }
}
167