Test Setup Failed
Push — master ( 8544cf...91812f )
by Arkadiusz
02:16
created

DecisionTreeRegressor::train()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 11
rs 9.9
c 0
b 0
f 0
cc 1
nc 1
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Regression;
6
7
use Phpml\Exception\InvalidOperationException;
8
use Phpml\Math\Statistic\Mean;
9
use Phpml\Math\Statistic\Variance;
10
use Phpml\Tree\CART;
11
use Phpml\Tree\Node\AverageNode;
12
use Phpml\Tree\Node\BinaryNode;
13
use Phpml\Tree\Node\DecisionNode;
14
15
final class DecisionTreeRegressor extends CART implements Regression
16
{
17
    /**
     * Upper bound on how many columns split() examines per call.
     *
     * Stays null until train() derives it as round(sqrt(feature count));
     * once a value is present train() keeps it.
     *
     * @var int|null
     */
20
    protected $maxFeatures;
21
22
    /**
     * Impurity threshold for the split search: split() stops looking at
     * further candidates as soon as one has an impurity less than or equal
     * to this value.
     *
     * NOTE(review): no assignment to this property is visible in this class —
     * confirm where it is initialised.
     *
     * @var float
     */
25
    protected $tolerance;
26
27
    /**
     * Feature column indices (0 .. feature count - 1) available to split().
     *
     * Filled by train() before the tree is grown and reset to an empty array
     * afterwards; split() shuffles it in place.
     *
     * @var int[]
     */
30
    protected $columns = [];
31
32
    /**
     * Trains the regression tree on a labelled dataset.
     *
     * The feature count is read from the first sample. The list of column
     * indices is populated only while the tree is being grown (split()
     * shuffles and slices it) and is emptied again afterwards. When no
     * feature limit has been set yet it is derived as round(sqrt(features))
     * and kept on the instance.
     *
     * @param array $samples rows of feature values; must contain at least one row
     * @param array $targets one target per sample — presumably keyed like $samples; confirm against grow()
     *
     * @throws \InvalidArgumentException when $samples is empty
     */
    public function train(array $samples, array $targets): void
    {
        // Without this guard an empty dataset fails inside count() with a
        // TypeError on PHP 8, or yields the bogus column list [0, -1] on PHP 7.
        if ($samples === []) {
            throw new \InvalidArgumentException('Cannot train on an empty set of samples');
        }

        // reset() works on the local copy and returns the first row even when
        // the array is not keyed from 0.
        $features = count(reset($samples));

        $this->columns = range(0, $features - 1);
        $this->maxFeatures = $this->maxFeatures ?? (int) round(sqrt($features));

        $this->grow($samples, $targets);

        // The column list is only needed while growing the tree.
        $this->columns = [];
    }
43
44
    /**
     * Produces one prediction per sample.
     *
     * Each sample is passed to search(); when the returned node is an
     * AverageNode its outcome() becomes the prediction, otherwise the
     * prediction for that sample is null.
     *
     * NOTE(review): the analyzer remark embedded in this method claims that
     * CART::search() always returns null. That method is not visible here —
     * confirm, because if true every prediction would be null.
     *
     * @param array $samples samples to predict for
     *
     * @return array predictions in the iteration order of $samples
     *
     * @throws InvalidOperationException when bare() reports true (presumably: no tree has been grown yet)
     */
    public function predict(array $samples)
45
    {
46
        if ($this->bare()) {
47
            throw new InvalidOperationException('Regressor must be trained first');
48
        }
49
50
        $predictions = [];
51
52
        foreach ($samples as $sample) {
53
            $node = $this->search($sample);
0 ignored issues
show
Bug introduced by
Are you sure the assignment to $node is correct as $this->search($sample) (which targets Phpml\Tree\CART::search()) seems to always return null.

This check looks for function or method calls that always return null and whose return value is assigned to a variable.

class A
{
    function getObject()
    {
        return null;
    }

}

$a = new A();
$object = $a->getObject();

The method getObject() can return nothing but null, so it makes no sense to assign that value to a variable.

The reason is most likely that a function or method is imcomplete or has been reduced for debug purposes.

Loading history...
54
55
            // Any node that is not an AverageNode yields null for this sample.
            $predictions[] = $node instanceof AverageNode
56
                ? $node->outcome()
57
                : null;
58
        }
59
60
        return $predictions;
61
    }
62
63
    /**
     * Searches a random subset of columns for the split with the lowest
     * impurity.
     *
     * The column list is shuffled in place and at most $this->maxFeatures
     * columns are examined. Every distinct value found in a column is tried
     * as a threshold. The search stops early as soon as a candidate's
     * impurity is less than or equal to $this->tolerance.
     *
     * @param array $samples samples to divide
     * @param array $targets targets belonging to $samples
     *
     * @return DecisionNode node built from the best column, value, groups and impurity found
     */
    protected function split(array $samples, array $targets): DecisionNode
    {
        shuffle($this->columns);
        $candidates = array_slice($this->columns, 0, $this->maxFeatures);

        // Best candidate seen so far; INF guarantees the first one replaces it.
        $winner = [
            'column' => null,
            'value' => null,
            'groups' => [],
            'impurity' => INF,
        ];

        foreach ($candidates as $column) {
            $thresholds = array_unique(array_column($samples, $column));

            foreach ($thresholds as $threshold) {
                $groups = $this->partition($column, $threshold, $samples, $targets);
                $impurity = $this->splitImpurity($groups);

                if ($impurity < $winner['impurity']) {
                    $winner = [
                        'column' => $column,
                        'value' => $threshold,
                        'groups' => $groups,
                        'impurity' => $impurity,
                    ];
                }

                // Good enough: abandon the remaining thresholds and columns.
                if ($impurity <= $this->tolerance) {
                    break 2;
                }
            }
        }

        return new DecisionNode(
            $winner['column'],
            $winner['value'],
            $winner['groups'],
            $winner['impurity']
        );
    }
94
95
    /**
     * Builds the terminal node for a set of targets.
     *
     * @param array $targets targets that ended up in this leaf
     *
     * @return BinaryNode an AverageNode holding the arithmetic mean, the
     *                    population variance and the number of targets
     */
    protected function terminate(array $targets): BinaryNode
    {
        $mean = Mean::arithmetic($targets);
        $variance = Variance::population($targets);
        $size = count($targets);

        return new AverageNode($mean, $variance, $size);
    }
99
100
    /**
     * Computes the impurity of a split as the weighted sum of the population
     * variance of each group's targets.
     *
     * Each group is a pair [samples, targets]. A group's weight is its target
     * count divided by the total sample count over all groups. Groups with
     * fewer than two targets contribute nothing.
     *
     * @param array $groups list of [samples, targets] pairs
     *
     * @return float weighted variance of the split
     */
    protected function splitImpurity(array $groups): float
    {
        $total = 0;

        foreach ($groups as $group) {
            $total += count($group[0]);
        }

        $weighted = 0.;

        foreach ($groups as $group) {
            $size = count($group[1]);

            if ($size >= 2) {
                $spread = Variance::population($group[1]);

                $weighted += ($size / $total) * $spread;
            }
        }

        return $weighted;
    }
122
123
    /**
124
     * @param int|float $value
125
     */
126
    /**
     * Divides samples and their targets into two groups by comparing one
     * column against a threshold.
     *
     * Samples whose column value is strictly less than $value go to the first
     * (left) group, all others to the second (right) group. Both groups are
     * re-indexed from 0.
     *
     * @param int       $column  index of the column to compare
     * @param int|float $value   threshold value
     * @param array     $samples samples to divide
     * @param array     $targets targets, looked up by the keys of $samples
     *
     * @return array [[leftSamples, leftTargets], [rightSamples, rightTargets]]
     */
    private function partition(int $column, $value, array $samples, array $targets): array
    {
        // Index 0 is the left group, index 1 the right group; each holds
        // [samples, targets].
        $groups = [
            [[], []],
            [[], []],
        ];

        foreach ($samples as $index => $sample) {
            $side = $sample[$column] < $value ? 0 : 1;

            $groups[$side][0][] = $sample;
            $groups[$side][1][] = $targets[$index];
        }

        return $groups;
    }
144
}
145