Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like DecisionStump often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes. You can also have a look at the cohesion graph to spot any un-connected, or weakly-connected components.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
While breaking up the class, it is a good idea to analyze how other classes use DecisionStump, and based on these observations, apply Extract Interface, too.
| 1 | <?php |
||
| 12 | class DecisionStump extends WeightedClassifier |
||
| 13 | { |
||
| 14 | use Predictable, OneVsRest; |
||
| 15 | |||
| 16 | const AUTO_SELECT = -1; |
||
| 17 | |||
| 18 | /** |
||
| 19 | * @var int |
||
| 20 | */ |
||
| 21 | protected $givenColumnIndex; |
||
| 22 | |||
| 23 | /** |
||
| 24 | * @var array |
||
| 25 | */ |
||
| 26 | protected $binaryLabels; |
||
| 27 | |||
| 28 | /** |
||
| 29 | * Lowest error rate obtained while training/optimizing the model |
||
| 30 | * |
||
| 31 | * @var float |
||
| 32 | */ |
||
| 33 | protected $trainingErrorRate; |
||
| 34 | |||
| 35 | /** |
||
| 36 | * @var int |
||
| 37 | */ |
||
| 38 | protected $column; |
||
| 39 | |||
| 40 | /** |
||
| 41 | * @var mixed |
||
| 42 | */ |
||
| 43 | protected $value; |
||
| 44 | |||
| 45 | /** |
||
| 46 | * @var string |
||
| 47 | */ |
||
| 48 | protected $operator; |
||
| 49 | |||
| 50 | /** |
||
| 51 | * @var array |
||
| 52 | */ |
||
| 53 | protected $columnTypes; |
||
| 54 | |||
| 55 | /** |
||
| 56 | * @var int |
||
| 57 | */ |
||
| 58 | protected $featureCount; |
||
| 59 | |||
| 60 | /** |
||
| 61 | * @var float |
||
| 62 | */ |
||
| 63 | protected $numSplitCount = 100.0; |
||
| 64 | |||
| 65 | /** |
||
| 66 | * Distribution of samples in the leaves |
||
| 67 | * |
||
| 68 | * @var array |
||
| 69 | */ |
||
| 70 | protected $prob; |
||
| 71 | |||
/**
 * Creates a DecisionStump, i.e. a DecisionTree that is only one level deep.
 * Stumps are generally used in the weak classifier role of ensemble algorithms.
 *
 * When a column index is supplied, the stump tries to place its decision node
 * on that column. With the default value (self::AUTO_SELECT) the stump decides
 * by itself, during training, which column to use for the decision.
 *
 * @param int $columnIndex Index of the column to split on, or self::AUTO_SELECT
 */
public function __construct(int $columnIndex = self::AUTO_SELECT)
{
    $this->givenColumnIndex = $columnIndex;
}
||
| 86 | |||
/**
 * Trains the stump on a two-label problem by selecting the single best split.
 *
 * Every candidate column (or only the column given to the constructor) is
 * probed, and the split with the lowest training error rate is stored on the
 * stump as value / operator / prob / column / trainingErrorRate.
 *
 * @param array $samples Rows of feature values; the column count is taken from $samples[0]
 * @param array $targets Target label of each sample, indexed like $samples
 * @param array $labels  The two labels of the binary problem
 *
 * @throws \Exception When the number of sample weights differs from the number of samples
 */
protected function trainBinary(array $samples, array $targets, array $labels)
{
    $this->binaryLabels = $labels;
    $this->featureCount = count($samples[0]);

    // A given column index must point at an existing column. Anything else
    // (too large, or negative) falls back to automatic column selection.
    if ($this->givenColumnIndex < 0 || $this->givenColumnIndex >= $this->featureCount) {
        $this->givenColumnIndex = self::AUTO_SELECT;
    }

    // Check the size of the weights given.
    // If none given, then assign 1 as a weight to each sample
    if (! empty($this->weights)) {
        if (count($this->weights) !== count($samples)) {
            throw new \Exception("Number of sample weights does not match with number of samples");
        }
    } else {
        $this->weights = array_fill(0, count($samples), 1);
    }

    // Determine type of each column as either "continuous" or "nominal"
    $this->columnTypes = DecisionTree::getColumnTypes($samples);

    // Try to find the best split in the columns of the dataset
    // by calculating error rate for each split point in each column
    if ($this->givenColumnIndex !== self::AUTO_SELECT) {
        $columns = [$this->givenColumnIndex];
    } else {
        $columns = range(0, $this->featureCount - 1);
    }

    // Baseline that any split with an error rate below 1.0 replaces
    $bestSplit = [
        'value' => 0, 'operator' => '',
        'prob' => [], 'column' => 0,
        'trainingErrorRate' => 1.0];
    foreach ($columns as $col) {
        if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
            $split = $this->getBestNumericalSplit($samples, $targets, $col);
        } else {
            $split = $this->getBestNominalSplit($samples, $targets, $col);
        }

        if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
            $bestSplit = $split;
        }
    }

    // Assign determined best values to the stump
    $this->value = $bestSplit['value'];
    $this->operator = $bestSplit['operator'];
    $this->prob = $bestSplit['prob'];
    $this->column = $bestSplit['column'];
    $this->trainingErrorRate = $bestSplit['trainingErrorRate'];
}
||
| 144 | |||
/**
 * While finding best split point for a numerical valued column,
 * DecisionStump looks for equally distanced values between minimum and maximum
 * values in the column. Given <i>$count</i> value determines how many split
 * points to be probed. The more split counts, the better performance but
 * worse processing time (Default value is 100.0)
 *
 * @param float $count Number of split points to probe; must be a finite number greater than zero
 *
 * @throws \InvalidArgumentException When $count is not a finite, positive number
 */
public function setNumericalSplitCount(float $count)
{
    // The count is the divisor used to derive the step size of the split
    // search, so zero, negative or non-finite values cannot be accepted.
    if (! is_finite($count) || $count <= 0.0) {
        throw new \InvalidArgumentException('Numerical split count must be a finite number greater than zero');
    }

    $this->numSplitCount = $count;
}
||
| 158 | |||
/**
 * Determines best split point for the given numerical column.
 *
 * For each of the operators '<=' and '>' the mean of the column is tried
 * first, followed by equally distanced thresholds between the minimum and
 * the maximum value of the column. The candidate with the lowest error rate
 * is returned.
 *
 * @param array $samples
 * @param array $targets
 * @param int $col
 *
 * @return array Keys: value, operator, prob, column, trainingErrorRate
 */
protected function getBestNumericalSplit(array $samples, array $targets, int $col)
{
    $values = array_column($samples, $col);
    // Trying all possible points may be accomplished in two general ways:
    // 1- Try all values in the $samples array ($values)
    // 2- Artificially split the range of values into several parts and try them
    // We choose the second one because it is faster in larger datasets
    $minValue = min($values);
    $maxValue = max($values);
    // Guard the division: a zero split count yields no step at all
    $stepSize = $this->numSplitCount > 0
        ? ($maxValue - $minValue) / $this->numSplitCount
        : 0;

    // The average value does not depend on the operator, so compute it once
    $mean = array_sum($values) / (float) count($values);

    $split = null;

    foreach (['<=', '>'] as $operator) {
        // Before trying all possible split points, let's first try
        // the average value for the cut point
        list($errorRate, $prob) = $this->calculateErrorRate($targets, $mean, $operator, $values);
        if ($split === null || $errorRate < $split['trainingErrorRate']) {
            $split = ['value' => $mean, 'operator' => $operator,
                'prob' => $prob, 'column' => $col,
                'trainingErrorRate' => $errorRate];
        }

        // A step that is not positive (e.g. every value in the column is the
        // same) would never move $step past $maxValue and loop forever.
        if (! ($stepSize > 0)) {
            continue;
        }

        // Try other possible points one by one
        for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
            $threshold = (float) $step;
            list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);
            if ($errorRate < $split['trainingErrorRate']) {
                $split = ['value' => $threshold, 'operator' => $operator,
                    'prob' => $prob, 'column' => $col,
                    'trainingErrorRate' => $errorRate];
            }
        }
    }

    return $split;
}
||
| 206 | |||
/**
 * Determines best split for the given nominal column.
 *
 * Every distinct value of the column is tried with the operators '=' and
 * '!=', and the candidate with the lowest error rate is returned.
 *
 * @param array $samples
 * @param array $targets
 * @param int $col
 *
 * @return array Keys: value, operator, prob, column, trainingErrorRate
 */
protected function getBestNominalSplit(array $samples, array $targets, int $col) : array
{
    $values = array_column($samples, $col);
    $valueCounts = array_count_values($values);
    $distinctVals = array_keys($valueCounts);

    $split = null;

    foreach (['=', '!='] as $operator) {
        foreach ($distinctVals as $val) {
            list($errorRate, $prob) = $this->calculateErrorRate($targets, $val, $operator, $values);

            // Keep the candidate only when it lowers the error rate
            if ($split === null || $errorRate < $split['trainingErrorRate']) {
                $split = ['value' => $val, 'operator' => $operator,
                    'prob' => $prob, 'column' => $col,
                    'trainingErrorRate' => $errorRate];
            }
        }
    }

    return $split;
}
||
| 236 | |||
| 237 | |||
/**
 * Applies the comparison named by $operator to the two given values.
 *
 * '=' tests identity (===), while '!=' and '<>' both test non-identity (!==).
 * The ordering operators '>', '>=', '<' and '<=' use PHP's regular comparison.
 * Any other operator yields false.
 *
 * @param mixed $leftValue
 * @param string $operator
 * @param mixed $rightValue
 *
 * @return bool
 */
protected function evaluate($leftValue, $operator, $rightValue)
{
    switch ($operator) {
        case '>':
            $outcome = $leftValue > $rightValue;
            break;
        case '>=':
            $outcome = $leftValue >= $rightValue;
            break;
        case '<':
            $outcome = $leftValue < $rightValue;
            break;
        case '<=':
            $outcome = $leftValue <= $rightValue;
            break;
        case '=':
            $outcome = $leftValue === $rightValue;
            break;
        case '!=':
        case '<>':
            $outcome = $leftValue !== $rightValue;
            break;
        default:
            // Unknown operators never match
            $outcome = false;
    }

    return $outcome;
}
||
| 260 | |||
/**
 * Calculates the ratio of wrong predictions based on the new threshold
 * value given as the parameter
 *
 * A value for which "value <operator> threshold" holds is predicted as the
 * first binary label, every other value as the second one. Each wrong
 * prediction counts with the weight of its sample.
 *
 * The threshold is deliberately left untyped: numerical splits pass a float,
 * nominal splits pass a category value, and that value must reach the
 * comparison with its original type.
 *
 * @param array $targets
 * @param mixed $threshold
 * @param string $operator
 * @param array $values
 *
 * @return array Two items: the weighted error rate, and for each binary label
 *               the proportion of samples predicted as that label whose target
 *               is that label
 */
protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values) : array
{
    $wrong = 0.0;
    $prob = [];
    $leftLabel = $this->binaryLabels[0];
    $rightLabel = $this->binaryLabels[1];

    foreach ($values as $index => $value) {
        if ($this->evaluate($value, $operator, $threshold)) {
            $predicted = $leftLabel;
        } else {
            $predicted = $rightLabel;
        }

        $target = $targets[$index];
        if (strval($predicted) != strval($target)) {
            $wrong += $this->weights[$index];
        }

        // Count how often each target label ends up in each predicted leaf
        if (! isset($prob[$predicted][$target])) {
            $prob[$predicted][$target] = 0;
        }
        $prob[$predicted][$target]++;
    }

    // Calculate probabilities: Proportion of labels in each leaf
    $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
    foreach ($prob as $leaf => $counts) {
        $leafTotal = (float) array_sum($counts);
        foreach ($counts as $label => $count) {
            if (strval($leaf) == strval($label)) {
                $dist[$leaf] = $count / $leafTotal;
            }
        }
    }

    return [$wrong / (float) array_sum($this->weights), $dist];
}
||
| 310 | |||
/**
 * Returns the probability of the sample of belonging to the given label
 *
 * When the label predicted for the sample differs from the requested label
 * (both compared as strings), the probability is 0.0. Otherwise the value
 * stored for that label in the leaf distribution of the stump is returned.
 *
 * @param array $sample
 * @param mixed $label
 *
 * @return float
 */
protected function predictProbability(array $sample, $label) : float
{
    $predictedLabel = $this->predictSampleBinary($sample);

    if (strval($predictedLabel) != strval($label)) {
        return 0.0;
    }

    return $this->prob[$label];
}
||
| 331 | |||
/**
 * Predicts one of the two binary labels for the sample by evaluating the
 * stump's condition (column, operator, value) on it.
 *
 * @param array $sample
 *
 * @return mixed The first binary label when the condition holds, otherwise the second
 */
protected function predictSampleBinary(array $sample)
{
    $conditionHolds = $this->evaluate($sample[$this->column], $this->operator, $this->value);

    return $this->binaryLabels[$conditionHolds ? 0 : 1];
}
||
| 345 | |||
/**
 * Intentionally empty: this method performs no work for the stump.
 *
 * @return void
 */
protected function resetBinary()
{
    // Nothing to reset
}
||
| 352 | |||
| 353 | /** |
||
| 354 | * @return string |
||
| 355 | */ |
||
| 356 | public function __toString() |
||
| 362 | } |
||
| 363 |
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using `empty(...)` or `! empty(...)` instead.