1 | <?php |
||
2 | |||
3 | declare(strict_types=1); |
||
4 | |||
5 | namespace Phpml\Tree; |
||
6 | |||
7 | use Phpml\Exception\InvalidArgumentException; |
||
8 | use Phpml\Tree\Node\BinaryNode; |
||
9 | use Phpml\Tree\Node\DecisionNode; |
||
10 | use Phpml\Tree\Node\LeafNode; |
||
11 | |||
/**
 * Base class for CART-style binary decision trees.
 *
 * The tree is grown iteratively (explicit stack, no recursion) from a set of
 * samples and targets. Concrete subclasses decide how a set of samples is
 * split into a decision node (`split()`) and how a set of targets is turned
 * into a terminal node (`terminate()`).
 */
abstract class CART
{
    /**
     * Tolerance added to a node's purity increase before it is compared with
     * the configured minimum, so that values equal up to rounding still pass.
     */
    private const PURITY_EPSILON = 1e-8;

    /**
     * Root of the grown tree, or null while the tree is bare.
     *
     * @var DecisionNode|null
     */
    protected $root;

    /**
     * Depth limit checked while growing; must be at least 1.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * A group is only split further when it holds more targets than this.
     *
     * @var int
     */
    protected $maxLeafSize;

    /**
     * Minimum purity increase a candidate split must reach to be kept.
     *
     * @var float
     */
    protected $minPurityIncrease;

    /**
     * Number of columns in the first training sample, set by grow().
     *
     * @var int
     */
    protected $featureCount;

    /**
     * @throws InvalidArgumentException when $maxDepth < 1, $maxLeafSize < 1 or $minPurityIncrease < 0
     */
    public function __construct(int $maxDepth = PHP_INT_MAX, int $maxLeafSize = 3, float $minPurityIncrease = 0.)
    {
        if ($maxDepth < 1) {
            throw new InvalidArgumentException('Max depth must be greater than 0');
        }

        if ($maxLeafSize < 1) {
            throw new InvalidArgumentException('Max leaf size must be greater than 0');
        }

        if ($minPurityIncrease < 0.) {
            throw new InvalidArgumentException('Min purity increase must be equal or greater than 0');
        }

        $this->maxDepth = $maxDepth;
        $this->maxLeafSize = $maxLeafSize;
        $this->minPurityIncrease = $minPurityIncrease;
    }

    /**
     * Returns the root node, or null if the tree has not been grown.
     */
    public function root(): ?DecisionNode
    {
        return $this->root;
    }

    /**
     * Returns the root node's height, or 0 for a bare tree.
     */
    public function height(): int
    {
        return $this->root !== null ? $this->root->height() : 0;
    }

    /**
     * Returns the root node's balance, or 0 for a bare tree.
     */
    public function balance(): int
    {
        return $this->root !== null ? $this->root->balance() : 0;
    }

    /**
     * Tells whether the tree has no root yet.
     */
    public function bare(): bool
    {
        return $this->root === null;
    }

    /**
     * Grows the tree from the given training data, replacing any previous root.
     *
     * @param array $samples training samples; the first one defines the feature count
     * @param array $targets target values passed to split() together with $samples
     *
     * @throws InvalidArgumentException when $samples is empty
     */
    public function grow(array $samples, array $targets): void
    {
        // Without at least one sample the feature count cannot be derived and
        // there is nothing to split on.
        if ($samples === []) {
            throw new InvalidArgumentException('Cannot grow a tree from an empty set of samples');
        }

        $this->featureCount = count($samples[0]);
        $this->root = $this->split($samples, $targets);

        // Each stack entry pairs a decision node that still needs children
        // with the depth at which that node sits.
        $stack = [[$this->root, 1]];

        while ($stack !== []) {
            [$current, $depth] = array_pop($stack);

            // Each group is a [samples, targets] pair produced by the split.
            [$left, $right] = $current->groups();

            // NOTE(review): presumably releases the data held by the node once
            // the groups have been read - confirm in DecisionNode::cleanup().
            $current->cleanup();

            $depth++;

            // One side received nothing: the split separates no targets, so
            // the same terminal node instance is attached on both sides.
            if ($left[1] === [] || $right[1] === []) {
                $node = $this->terminate(array_merge($left[1], $right[1]));

                $current->attachLeft($node);
                $current->attachRight($node);

                continue;
            }

            // Depth limit reached: both children become terminal nodes.
            if ($depth >= $this->maxDepth) {
                $current->attachLeft($this->terminate($left[1]));
                $current->attachRight($this->terminate($right[1]));

                continue;
            }

            $child = $this->splitIfWorthwhile($left);

            if ($child !== null) {
                $current->attachLeft($child);

                $stack[] = [$child, $depth];
            } else {
                $current->attachLeft($this->terminate($left[1]));
            }

            $child = $this->splitIfWorthwhile($right);

            if ($child !== null) {
                $current->attachRight($child);

                $stack[] = [$child, $depth];
            } else {
                $current->attachRight($this->terminate($right[1]));
            }
        }
    }

    /**
     * Walks the tree for one sample and returns the node where the walk ends.
     *
     * At each decision node the sample's value in the node's column is compared
     * with the node's value: string values go left on strict equality, any
     * other value goes left when the sample's value is smaller. The walk stops
     * at the first node that is not a decision node.
     *
     * @param array $sample feature values indexed by column
     *
     * @return BinaryNode|null the reached node, or null for a bare tree or a missing child
     */
    public function search(array $sample): ?BinaryNode
    {
        $current = $this->root;

        // Only decision nodes can be descended. Stopping on anything else
        // (leaf, other node type, null) guarantees the loop terminates.
        while ($current instanceof DecisionNode) {
            $value = $current->value();
            $feature = $sample[$current->column()];

            if (is_string($value)) {
                $goLeft = $feature === $value;
            } else {
                $goLeft = $feature < $value;
            }

            $current = $goLeft ? $current->left() : $current->right();
        }

        return $current;
    }

    /**
     * Builds a decision node that splits the given samples and targets.
     */
    abstract protected function split(array $samples, array $targets): DecisionNode;

    /**
     * Builds the terminal node for the given targets.
     */
    abstract protected function terminate(array $targets): BinaryNode;

    /**
     * Splits a [samples, targets] group when it is large enough and the split
     * gains enough purity; returns null when the group should be terminated.
     */
    private function splitIfWorthwhile(array $group): ?DecisionNode
    {
        if (count($group[1]) <= $this->maxLeafSize) {
            return null;
        }

        $node = $this->split($group[0], $group[1]);

        if ($node->purityIncrease() + self::PURITY_EPSILON > $this->minPurityIncrease) {
            return $node;
        }

        return null;
    }
}
||
177 |
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using `empty(...)` or `!empty(...)` instead.