Complex classes like DecisionTree often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes. You can also have a look at the cohesion graph to spot any unconnected or weakly connected components.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
While breaking up the class, it is a good idea to analyze how other classes use DecisionTree, and based on these observations, apply Extract Interface, too.
| 1 | <?php |
||
| 13 | class DecisionTree implements Classifier |
||
| 14 | { |
||
| 15 | use Trainable, Predictable; |
||
| 16 | |||
| 17 | const CONTINUOUS = 1; |
||
| 18 | const NOMINAL = 2; |
||
| 19 | |||
| 20 | /** |
||
| 21 | * @var array Type of each column (compared against self::CONTINUOUS when a split is built); train() recomputes it when its size differs from the feature count and fills in null entries from the detected types |
||
| 22 | */ |
||
| 23 | protected $columnTypes = []; |
||
| 24 | |||
| 25 | /** |
||
| 26 | * @var array Distinct target values, collected from the stored training targets by train() |
||
| 27 | */ |
||
| 28 | private $labels = []; |
||
| 29 | |||
| 30 | /** |
||
| 31 | * @var int Number of values in the first stored training sample, set by train() |
||
| 32 | */ |
||
| 33 | private $featureCount = 0; |
||
| 34 | |||
| 35 | /** |
||
| 36 | * @var DecisionTreeLeaf|null Root returned by getSplitLeaf() in train(); null by default |
||
| 37 | */ |
||
| 38 | protected $tree = null; |
||
| 39 | |||
| 40 | /** |
||
| 41 | * @var int |
||
| 42 | */ |
||
| 43 | protected $maxDepth; |
||
| 44 | |||
| 45 | public static $categoricalColumnMinimumUniqueValueCount = 0.2; |
||
| 46 | |||
| 47 | /** |
||
| 48 | * @var int |
||
| 49 | */ |
||
| 50 | public $actualDepth = 0; |
||
| 51 | |||
| 52 | /** |
||
| 53 | * @var int |
||
| 54 | */ |
||
| 55 | private $numUsableFeatures = 0; |
||
| 56 | |||
| 57 | /** |
||
| 58 | * @var array |
||
| 59 | */ |
||
| 60 | private $selectedFeatures; |
||
| 61 | |||
| 62 | /** |
||
| 63 | * @var array|null Null by default and reset to null by every train() call |
||
| 64 | */ |
||
| 65 | private $featureImportances = null; |
||
| 66 | |||
| 67 | /** |
||
| 68 | * One name per column; when still null, train() sets it to the column indices, otherwise train() truncates or pads it with indices to match the feature count |
||
| 69 | * @var array|null |
||
| 70 | */ |
||
| 71 | private $columnNames = null; |
||
| 72 | |||
| 73 | /** |
||
| 74 | * @param int $maxDepth |
||
| 75 | */ |
||
| 76 | public function __construct(int $maxDepth = 10) |
||
| 80 | |||
/**
 * Trains the tree on the given data combined with the data of earlier calls.
 *
 * The new samples and targets are appended to the stored training set and
 * the tree is rebuilt from the whole combined set. Column types that are
 * missing or null are detected from the samples, cached feature importances
 * are discarded, and the column names are adjusted to the feature count.
 *
 * @param array $samples Rows of feature values; the first stored row determines the feature count
 * @param array $targets Target values belonging to the rows of $samples
 *
 * @throws \InvalidArgumentException If there is no sample at all to train on
 */
public function train(array $samples, array $targets)
{
    // Validate before touching any state: with an empty training set
    // $this->samples[0] does not exist and range(0, -1) would produce the
    // bogus record list [0, -1].
    $allSamples = array_merge($this->samples, $samples);
    if (count($allSamples) === 0) {
        throw new \InvalidArgumentException('Cannot train a decision tree without any samples');
    }

    $this->samples = $allSamples;
    $this->targets = array_merge($this->targets, $targets);

    $this->featureCount = count($this->samples[0]);
    if (count($this->columnTypes) !== $this->featureCount) {
        // No usable type list was given: detect every column type
        $this->columnTypes = self::getColumnTypes($this->samples);
    } elseif (in_array(null, $this->columnTypes, true)) {
        // Keep the explicitly given types and only fill in the null entries
        foreach (self::getColumnTypes($this->samples) as $column => $detectedType) {
            if ($this->columnTypes[$column] === null) {
                $this->columnTypes[$column] = $detectedType;
            }
        }
    }

    $this->labels = array_keys(array_count_values($this->targets));
    $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

    // Each time the tree is trained, feature importances are reset so that
    // they have to be computed again from the new data
    $this->featureImportances = null;

    // If column names were given or computed before, they are kept and only
    // adjusted to the feature count instead of being overwritten
    if ($this->columnNames === null) {
        $this->columnNames = range(0, $this->featureCount - 1);
    } elseif (count($this->columnNames) > $this->featureCount) {
        $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
    } elseif (count($this->columnNames) < $this->featureCount) {
        // Pad the missing names with the corresponding column indices
        $this->columnNames = array_merge(
            $this->columnNames,
            range(count($this->columnNames), $this->featureCount - 1)
        );
    }
}
||
| 120 | |||
| 121 | /** |
||
| 122 | * @param array $samples |
||
| 123 | * |
||
| 124 | * @return array |
||
| 125 | */ |
||
| 126 | public static function getColumnTypes(array $samples) : array |
||
| 138 | |||
| 139 | /** |
||
| 140 | * @param array $records |
||
| 141 | * @param int $depth |
||
| 142 | * |
||
| 143 | * @return DecisionTreeLeaf |
||
| 144 | */ |
||
| 145 | protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLeaf |
||
| 202 | |||
/**
 * Picks, among the selected features, the split with the lowest Gini index
 * for the given subset of the training records.
 *
 * For each candidate column the most frequent non-null value is used as the
 * base value of the split. The first candidate is always accepted; a later
 * one replaces it only when its Gini index is strictly lower.
 *
 * For continuous columns the base value is parsed as "<operator> <number>"
 * and both parts are stored on the leaf (presumably preprocess() produces
 * values in that form; confirm against its implementation).
 *
 * @param array $records Keys of the stored samples/targets to evaluate
 *
 * @return DecisionTreeLeaf The best candidate; if no feature is selected the
 *                          result stays null, which the declared return type
 *                          does not allow
 */
protected function getBestSplit(array $records) : DecisionTreeLeaf
{
    $recordKeys = array_flip($records);
    $targets = array_intersect_key($this->targets, $recordKeys);
    $samples = array_combine(
        $records,
        $this->preprocess(array_intersect_key($this->samples, $recordKeys))
    );

    $winner = null;
    $lowestGini = 1;

    foreach ($this->getSelectedFeatures() as $column) {
        // Collect the non-null values of this column, keyed by record
        $columnValues = [];
        foreach ($samples as $recordKey => $row) {
            if (is_null($row[$column])) {
                continue;
            }
            $columnValues[$recordKey] = $row[$column];
        }

        // The most frequent value becomes the base value of the candidate
        $frequencies = array_count_values($columnValues);
        arsort($frequencies);
        $baseValue = key($frequencies);
        $gini = $this->getGiniIndex($baseValue, $columnValues, $targets);

        // Skip candidates that do not beat the current winner
        if ($winner !== null && !($lowestGini > $gini)) {
            continue;
        }

        $isContinuous = $this->columnTypes[$column] == self::CONTINUOUS;

        $leaf = new DecisionTreeLeaf();
        $leaf->value = $baseValue;
        $leaf->giniIndex = $gini;
        $leaf->columnIndex = $column;
        $leaf->isContinuous = $isContinuous;
        $leaf->records = $records;

        // For a numeric column the comparison operator and the original
        // numeric value are also saved on the leaf for later access
        if ($isContinuous) {
            $parts = [];
            preg_match("/^([<>=]{1,2})\s*(.*)/", strval($leaf->value), $parts);
            $leaf->operator = $parts[1];
            $leaf->numericValue = floatval($parts[2]);
        }

        $winner = $leaf;
        $lowestGini = $gini;
    }

    return $winner;
}
||
| 252 | |||
| 253 | /** |
||
| 254 | * Returns available features/columns to the tree for the decision making |
||
| 255 | * process. <br> |
||
| 256 | * |
||
| 257 | * If a number is given with setNumFeatures() method, then a random selection |
||
| 258 | * of features up to this number is returned. <br> |
||
| 259 | * |
||
| 260 | * If some features are manually selected by use of setSelectedFeatures(), |
||
| 261 | * then only these features are returned <br> |
||
| 262 | * |
||
| 263 | * If any of above methods were not called beforehand, then all features |
||
| 264 | * are returned by default. |
||
| 265 | * |
||
| 266 | * @return array |
||
| 267 | */ |
||
| 268 | protected function getSelectedFeatures() : array |
||
| 289 | |||
| 290 | /** |
||
| 291 | * @param mixed $baseValue |
||
| 292 | * @param array $colValues |
||
| 293 | * @param array $targets |
||
| 294 | * |
||
| 295 | * @return float |
||
| 296 | */ |
||
| 297 | public function getGiniIndex($baseValue, array $colValues, array $targets) : float |
||
| 328 | |||
| 329 | /** |
||
| 330 | * @param array $samples |
||
| 331 | * |
||
| 332 | * @return array |
||
| 333 | */ |
||
| 334 | protected function preprocess(array $samples) : array |
||
| 357 | |||
| 358 | /** |
||
| 359 | * @param array $columnValues |
||
| 360 | * |
||
| 361 | * @return bool |
||
| 362 | */ |
||
| 363 | protected static function isCategoricalColumn(array $columnValues) : bool |
||
| 387 | |||
| 388 | /** |
||
| 389 | * This method is used to set number of columns to be used |
||
| 390 | * when deciding a split at an internal node of the tree. <br> |
||
| 391 | * If the value is given 0, then all features are used (default behaviour), |
||
| 392 | * otherwise the given value will be used as a maximum for number of columns |
||
| 393 | * randomly selected for each split operation. |
||
| 394 | * |
||
| 395 | * @param int $numFeatures |
||
| 396 | * |
||
| 397 | * @return $this |
||
| 398 | * |
||
| 399 | * @throws InvalidArgumentException |
||
| 400 | */ |
||
| 401 | public function setNumFeatures(int $numFeatures) |
||
| 411 | |||
| 412 | /** |
||
| 413 | * Used to set predefined features to consider while deciding which column to use for a split |
||
| 414 | * |
||
| 415 | * @param array $selectedFeatures |
||
| 416 | */ |
||
| 417 | protected function setSelectedFeatures(array $selectedFeatures) |
||
| 421 | |||
| 422 | /** |
||
| 423 | * A string array to represent columns. Useful when HTML output or |
||
| 424 | * column importances are desired to be inspected. |
||
| 425 | * |
||
| 426 | * @param array $names |
||
| 427 | * |
||
| 428 | * @return $this |
||
| 429 | * |
||
| 430 | * @throws InvalidArgumentException |
||
| 431 | */ |
||
| 432 | public function setColumnNames(array $names) |
||
| 442 | |||
| 443 | /** |
||
| 444 | * @return string |
||
| 445 | */ |
||
| 446 | public function getHtml() |
||
| 450 | |||
| 451 | /** |
||
| 452 | * This will return an array including an importance value for |
||
| 453 | * each column in the given dataset. The importance values are |
||
| 454 | * normalized and their total makes 1.<br/> |
||
| 455 | * |
||
| 456 | * @return array |
||
| 457 | */ |
||
| 458 | public function getFeatureImportances() |
||
| 488 | |||
| 489 | /** |
||
| 490 | * Collects and returns an array of internal nodes that use the given |
||
| 491 | * column as a split criterion |
||
| 492 | * |
||
| 493 | * @param int $column |
||
| 494 | * @param DecisionTreeLeaf $node |
||
| 495 | * |
||
| 496 | * @return array |
||
| 497 | */ |
||
| 498 | protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) : array |
||
| 523 | |||
| 524 | /** |
||
| 525 | * @param array $sample |
||
| 526 | * |
||
| 527 | * @return mixed |
||
| 528 | */ |
||
| 529 | protected function predictSample(array $sample) |
||
| 546 | |||
| 547 | /** |
||
| 548 | * @return integer[]|null[] |
||
| 549 | */ |
||
| 550 | public function getInstanceColumnTypes() { |
||
| 553 | |||
| 554 | /** |
||
| 555 | * @param integer[]|null[] $columnTypes |
||
| 556 | */ |
||
| 557 | public function setInstanceColumnTypes(array $columnTypes) { |
||
| 560 | |||
| 561 | /** |
||
| 562 | * @param array $values |
||
| 563 | * |
||
| 564 | * @return array |
||
| 565 | */ |
||
| 566 | protected static function arrayCountValues(array $values) { |
||
| 576 | } |
||
| 577 |
Our type inference engine has found an assignment to a property that is incompatible with the declared type of that property.
Either this assignment is in error or the assigned type should be added to the documentation/type hint for that property.