php-ai /
php-ml
| 1 | <?php |
||
| 2 | |||
| 3 | declare(strict_types=1); |
||
| 4 | |||
| 5 | namespace Phpml\Tree; |
||
| 6 | |||
| 7 | use Phpml\Exception\InvalidArgumentException; |
||
| 8 | use Phpml\Tree\Node\BinaryNode; |
||
| 9 | use Phpml\Tree\Node\DecisionNode; |
||
| 10 | use Phpml\Tree\Node\LeafNode; |
||
| 11 | |||
/**
 * Base class for CART-style binary decision trees.
 *
 * The tree is grown iteratively (explicit stack, no recursion) from a training
 * set. Concrete subclasses decide how a group of samples is split into a
 * decision node (split()) and how a group of targets is turned into a terminal
 * node (terminate()).
 */
abstract class CART
{
    /**
     * Root of the grown tree, or null while the tree is bare.
     *
     * @var DecisionNode|null
     */
    protected $root;

    /**
     * Depth at which growing stops and the children of a node become terminal nodes.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * A group is only split further when it holds more targets than this.
     *
     * @var int
     */
    protected $maxLeafSize;

    /**
     * Purity increase a candidate split has to exceed (within a 1e-8 tolerance)
     * to be kept as a decision node.
     *
     * @var float
     */
    protected $minPurityIncrease;

    /**
     * Number of columns of the first training sample, recorded by grow().
     *
     * @var int
     */
    protected $featureCount;

    /**
     * @param int   $maxDepth          must be at least 1
     * @param int   $maxLeafSize       must be at least 1
     * @param float $minPurityIncrease must not be negative
     *
     * @throws InvalidArgumentException when any argument is out of range
     */
    public function __construct(int $maxDepth = PHP_INT_MAX, int $maxLeafSize = 3, float $minPurityIncrease = 0.)
    {
        if ($maxDepth < 1) {
            throw new InvalidArgumentException('Max depth must be greater than 0');
        }

        if ($maxLeafSize < 1) {
            throw new InvalidArgumentException('Max leaf size must be greater than 0');
        }

        if ($minPurityIncrease < 0.) {
            throw new InvalidArgumentException('Min purity increase must be equal or greater than 0');
        }

        $this->maxDepth = $maxDepth;
        $this->maxLeafSize = $maxLeafSize;
        $this->minPurityIncrease = $minPurityIncrease;
    }

    /**
     * Returns the root node, or null if the tree has not been grown yet.
     */
    public function root(): ?DecisionNode
    {
        return $this->root;
    }

    /**
     * Returns the height reported by the root node, or 0 for a bare tree.
     */
    public function height(): int
    {
        return $this->root !== null ? $this->root->height() : 0;
    }

    /**
     * Returns the balance reported by the root node, or 0 for a bare tree.
     */
    public function balance(): int
    {
        return $this->root !== null ? $this->root->balance() : 0;
    }

    /**
     * Tells whether the tree has no root yet.
     */
    public function bare(): bool
    {
        return $this->root === null;
    }

    /**
     * Grows the tree from a training set, replacing any previous root.
     *
     * Nodes are processed depth-first using an explicit stack of
     * [node, depth] pairs; the root sits at depth 1.
     *
     * @param array $samples training samples; the first one determines the feature count
     * @param array $targets targets, passed to split() together with the samples
     *
     * @throws InvalidArgumentException when no samples are given
     */
    public function grow(array $samples, array $targets): void
    {
        if ($samples === []) {
            throw new InvalidArgumentException('Cannot grow a tree from an empty sample set');
        }

        $this->featureCount = count($samples[0]);
        $this->root = $this->split($samples, $targets);

        $stack = [[$this->root, 1]];

        while ($stack !== []) {
            [$current, $depth] = array_pop($stack);

            // Each group is a [samples, targets] pair.
            [$left, $right] = $current->groups();

            // The groups have been extracted, so the node may drop its own copy.
            $current->cleanup();

            ++$depth;

            // Degenerate split: everything landed on one side, so both
            // branches share a single terminal node built from all targets.
            if ($left[1] === [] || $right[1] === []) {
                $node = $this->terminate(array_merge($left[1], $right[1]));

                $current->attachLeft($node);
                $current->attachRight($node);

                continue;
            }

            // Depth limit reached: close both branches with terminal nodes.
            if ($depth >= $this->maxDepth) {
                $current->attachLeft($this->terminate($left[1]));
                $current->attachRight($this->terminate($right[1]));

                continue;
            }

            // Left branch is resolved before the right one so that split()
            // calls and stack pushes happen in a fixed order.
            $leftNode = $this->trySplit($left);

            if ($leftNode !== null) {
                $current->attachLeft($leftNode);

                $stack[] = [$leftNode, $depth];
            } else {
                $current->attachLeft($this->terminate($left[1]));
            }

            $rightNode = $this->trySplit($right);

            if ($rightNode !== null) {
                $current->attachRight($rightNode);

                $stack[] = [$rightNode, $depth];
            } else {
                $current->attachRight($this->terminate($right[1]));
            }
        }
    }

    /**
     * Walks the tree for a sample and returns the node where the descent ends.
     *
     * At each decision node the sample's value in the node's column is
     * compared with the node's value: string values branch left on strict
     * equality, all other values branch left when the sample value is smaller.
     *
     * @param array $sample feature values indexed by column
     *
     * @return BinaryNode|null the first non-decision node reached, or null
     *                         when the tree is bare or a branch is missing
     */
    public function search(array $sample): ?BinaryNode
    {
        $current = $this->root;

        // Looping only while on a decision node guarantees termination: any
        // other kind of node (or null) ends the descent immediately.
        while ($current instanceof DecisionNode) {
            $value = $current->value();
            $feature = $sample[$current->column()];

            if (is_string($value)) {
                $goLeft = $feature === $value;
            } else {
                $goLeft = $feature < $value;
            }

            $current = $goLeft ? $current->left() : $current->right();
        }

        return $current;
    }

    /**
     * Builds the decision node that partitions the given samples and targets.
     */
    abstract protected function split(array $samples, array $targets): DecisionNode;

    /**
     * Builds the terminal node for the given targets.
     */
    abstract protected function terminate(array $targets): BinaryNode;

    /**
     * Splits a group if it is large enough and the split is worth keeping.
     *
     * @param array $group a [samples, targets] pair
     *
     * @return DecisionNode|null the new decision node, or null when the group
     *                           should become a terminal node instead
     */
    private function trySplit(array $group): ?DecisionNode
    {
        [$samples, $targets] = $group;

        if (count($targets) <= $this->maxLeafSize) {
            return null;
        }

        $node = $this->split($samples, $targets);

        // The small epsilon keeps splits whose gain equals the threshold up to rounding error.
        if ($node->purityIncrease() + 1e-8 > $this->minPurityIncrease) {
            return $node;
        }

        return null;
    }
}
||
| 177 |
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using `empty(...)` or `!empty(...)` instead.