Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
| 1 | <?php |
||
| 13 | class DecisionStump extends WeightedClassifier |
||
| 14 | { |
||
| 15 | use Predictable, OneVsRest; |
||
| 16 | |||
| 17 | const AUTO_SELECT = -1; |
||
| 18 | |||
| 19 | /** |
||
| 20 | * @var int |
||
| 21 | */ |
||
| 22 | protected $givenColumnIndex; |
||
| 23 | |||
| 24 | /** |
||
| 25 | * @var array |
||
| 26 | */ |
||
| 27 | protected $binaryLabels; |
||
| 28 | |||
| 29 | /** |
||
| 30 | * Lowest error rate obtained while training/optimizing the model |
||
| 31 | * |
||
| 32 | * @var float |
||
| 33 | */ |
||
| 34 | protected $trainingErrorRate; |
||
| 35 | |||
| 36 | /** |
||
| 37 | * @var int |
||
| 38 | */ |
||
| 39 | protected $column; |
||
| 40 | |||
| 41 | /** |
||
| 42 | * @var mixed |
||
| 43 | */ |
||
| 44 | protected $value; |
||
| 45 | |||
| 46 | /** |
||
| 47 | * @var string |
||
| 48 | */ |
||
| 49 | protected $operator; |
||
| 50 | |||
| 51 | /** |
||
| 52 | * @var array |
||
| 53 | */ |
||
| 54 | protected $columnTypes; |
||
| 55 | |||
| 56 | /** |
||
| 57 | * @var int |
||
| 58 | */ |
||
| 59 | protected $featureCount; |
||
| 60 | |||
| 61 | /** |
||
| 62 | * @var float |
||
| 63 | */ |
||
| 64 | protected $numSplitCount = 100.0; |
||
| 65 | |||
| 66 | /** |
||
| 67 | * Distribution of samples in the leaves |
||
| 68 | * |
||
| 69 | * @var array |
||
| 70 | */ |
||
| 71 | protected $prob; |
||
| 72 | |||
/**
 * Creates a decision stump: a decision tree that is exactly one level deep,
 * generally used in the weak-classifier role of ensemble algorithms.
 *
 * Pass a column index to make the stump build its decision node on that
 * column. With the default, self::AUTO_SELECT (-1), the stump chooses the
 * column itself while training.
 *
 * @param int $columnIndex Column to split on, or self::AUTO_SELECT
 */
public function __construct(int $columnIndex = self::AUTO_SELECT)
{
    $this->givenColumnIndex = $columnIndex;
}
||
| 87 | |||
/**
 * Trains the stump on a two-label problem: every candidate column is probed
 * and the split with the lowest training error rate is stored on the stump.
 *
 * @param array $samples Rows of feature values, each row indexed by column
 * @param array $targets Target label of every sample
 * @param array $labels  The two labels; index 0 is the label predicted when
 *                       the split condition holds, index 1 otherwise
 *
 * @throws \Exception When sample weights are set but their number differs
 *                    from the number of samples
 */
protected function trainBinary(array $samples, array $targets, array $labels)
{
    $this->binaryLabels = $labels;
    $this->featureCount = count($samples[0]);

    // A requested column must exist in the dataset. Anything outside
    // 0..featureCount-1 (too large, or negative) falls back to automatic
    // selection instead of being used as a column index later on.
    if ($this->givenColumnIndex < 0 || $this->givenColumnIndex >= $this->featureCount) {
        $this->givenColumnIndex = self::AUTO_SELECT;
    }

    // Sample weights are optional: validate their number when present,
    // otherwise give every sample a weight of 1.
    $sampleCount = count($samples);
    if (!empty($this->weights)) {
        if (count($this->weights) !== $sampleCount) {
            throw new \Exception('Number of sample weights does not match with number of samples');
        }
    } else {
        $this->weights = array_fill(0, $sampleCount, 1);
    }

    // Determine type of each column as either "continuous" or "nominal"
    $this->columnTypes = DecisionTree::getColumnTypes($samples);

    // Probe either the single requested column or all of them
    $columns = [$this->givenColumnIndex];
    if ($this->givenColumnIndex === self::AUTO_SELECT) {
        $columns = range(0, $this->featureCount - 1);
    }

    // Baseline that any split with an error rate below 1.0 replaces
    $bestSplit = [
        'value' => 0, 'operator' => '',
        'prob' => [], 'column' => 0,
        'trainingErrorRate' => 1.0,
    ];
    foreach ($columns as $col) {
        if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
            $split = $this->getBestNumericalSplit($samples, $targets, $col);
        } else {
            $split = $this->getBestNominalSplit($samples, $targets, $col);
        }

        if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
            $bestSplit = $split;
        }
    }

    // Copy the winning split onto the stump; each key of the split array
    // is written to the property of the same name.
    foreach ($bestSplit as $name => $value) {
        $this->{$name} = $value;
    }
}
||
| 147 | |||
/**
 * Sets how many equally distanced split points are probed between the
 * minimum and maximum value of a numerical column while searching for the
 * best split. More split points can yield a better split at the cost of
 * processing time (default value is 100.0).
 *
 * @param float $count Number of intervals the value range is divided into;
 *                     must be a finite number greater than zero
 *
 * @throws \InvalidArgumentException When $count is not a finite positive number
 */
public function setNumericalSplitCount(float $count)
{
    // The step between split points is (max - min) / count: zero would divide
    // by zero, while a negative or infinite count would yield a step that
    // never moves the scan past the maximum value.
    if (!is_finite($count) || $count <= 0) {
        throw new \InvalidArgumentException('Numerical split count must be a finite number greater than zero');
    }

    $this->numSplitCount = $count;
}
||
| 161 | |||
/**
 * Determines the best split point for the given numerical column.
 *
 * The candidates are the mean of the column followed by equally distanced
 * points from the column minimum up to its maximum. Each candidate is tried
 * with both the "<=" and the ">" operator, and the combination with the
 * lowest error rate is returned.
 *
 * @param array $samples
 * @param array $targets
 * @param int   $col
 *
 * @return array Split described by the keys value, operator, prob, column
 *               and trainingErrorRate
 */
protected function getBestNumericalSplit(array $samples, array $targets, int $col)
{
    $values = array_column($samples, $col);

    // Trying all possible points may be accomplished in two general ways:
    // 1- Try all values in the $samples array ($values)
    // 2- Artificially split the range of values into several parts and try them
    // We choose the second one because it is faster in larger datasets
    $minValue = min($values);
    $maxValue = max($values);
    $stepSize = $this->numSplitCount > 0
        ? ($maxValue - $minValue) / $this->numSplitCount
        : 0.0;

    // The mean is always probed first, then the stepped points in ascending order
    $thresholds = [array_sum($values) / (float) count($values)];

    // A constant column (min == max) gives a zero step that would never move
    // the scan past the maximum; in that case the mean is the only candidate.
    if ($stepSize > 0) {
        for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
            $thresholds[] = (float) $step;
        }
    }

    $split = null;

    foreach (['<=', '>'] as $operator) {
        foreach ($thresholds as $threshold) {
            list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);

            // Strict "<" keeps the earliest candidate when error rates tie
            if ($split === null || $errorRate < $split['trainingErrorRate']) {
                $split = [
                    'value' => $threshold, 'operator' => $operator,
                    'prob' => $prob, 'column' => $col,
                    'trainingErrorRate' => $errorRate,
                ];
            }
        }
    }

    return $split;
}
||
| 209 | |||
/**
 * Determines the best split for the given nominal column.
 *
 * Every distinct value of the column is tried with both the "=" and the
 * "!=" operator, and the combination with the lowest error rate is returned.
 *
 * @param array $samples
 * @param array $targets
 * @param int   $col
 *
 * @return array Split described by the keys value, operator, prob, column
 *               and trainingErrorRate
 */
protected function getBestNominalSplit(array $samples, array $targets, int $col) : array
{
    $values = array_column($samples, $col);
    $distinctVals = array_keys(array_count_values($values));

    $split = null;

    foreach (['=', '!='] as $operator) {
        foreach ($distinctVals as $val) {
            list($errorRate, $prob) = $this->calculateErrorRate($targets, $val, $operator, $values);

            // Keep the candidate with the LOWER error rate. The comparison
            // used to be reversed, which retained the worst split instead.
            if ($split === null || $errorRate < $split['trainingErrorRate']) {
                $split = [
                    'value' => $val, 'operator' => $operator,
                    'prob' => $prob, 'column' => $col,
                    'trainingErrorRate' => $errorRate,
                ];
            }
        }
    }

    return $split;
}
||
| 239 | |||
/**
 * Compares two values using the strategy object that
 * CompareStrategyFactory creates for the given operator.
 *
 * @param mixed  $leftValue
 * @param string $operator
 * @param mixed  $rightValue
 *
 * @return bool
 */
protected function evaluate($leftValue, string $operator, $rightValue)
{
    $strategy = CompareStrategyFactory::create($operator);

    return $strategy->compare($leftValue, $rightValue);
}
||
| 253 | |||
/**
 * Calculates the ratio of wrong predictions a candidate split would make.
 *
 * A value for which "value <operator> threshold" holds is predicted as the
 * first binary label, any other value as the second one. The error rate is
 * the summed weight of the wrongly predicted samples divided by the sum of
 * all weights.
 *
 * @param array  $targets
 * @param float  $threshold
 * @param string $operator
 * @param array  $values
 *
 * @return array Two elements: the error rate, and an array keyed by the
 *               binary labels holding, per predicted label, the proportion
 *               of samples in that leaf whose target equals the label
 */
protected function calculateErrorRate(array $targets, float $threshold, string $operator, array $values) : array
{
    // NOTE(review): $threshold is declared float, yet getBestNominalSplit()
    // passes raw column values here - confirm nominal values are numeric.
    $labelWhenTrue = $this->binaryLabels[0];
    $labelWhenFalse = $this->binaryLabels[1];

    $wrongWeight = 0.0;
    $leafCounts = [];

    foreach ($values as $index => $value) {
        $predicted = $this->evaluate($value, $operator, $threshold)
            ? $labelWhenTrue
            : $labelWhenFalse;

        $target = $targets[$index];
        if ((string) $predicted != (string) $target) {
            $wrongWeight += $this->weights[$index];
        }

        // Count the targets seen in each leaf, keyed by predicted label
        $leafCounts[$predicted][$target] = ($leafCounts[$predicted][$target] ?? 0) + 1;
    }

    // Calculate probabilities: proportion of matching labels in each leaf
    $dist = array_combine($this->binaryLabels, [0.0, 0.0]);
    foreach ($leafCounts as $leaf => $counts) {
        $leafTotal = (float) array_sum($counts);
        foreach ($counts as $label => $count) {
            if ((string) $leaf != (string) $label) {
                continue;
            }

            $dist[$leaf] = $count / $leafTotal;
        }
    }

    $totalWeight = (float) array_sum($this->weights);

    return [$wrongWeight / $totalWeight, $dist];
}
||
| 303 | |||
/**
 * Returns the probability of the sample belonging to the given label.
 *
 * When the stump predicts a different label for the sample, the result is
 * 0.0. Otherwise it is the value stored for that label in $this->prob,
 * i.e. the proportion recorded for the label during training.
 *
 * @param array $sample
 * @param mixed $label
 *
 * @return float
 */
protected function predictProbability(array $sample, $label) : float
{
    $predicted = $this->predictSampleBinary($sample);

    // Labels are compared through their string representations
    if ((string) $predicted != (string) $label) {
        return 0.0;
    }

    return $this->prob[$label];
}
||
| 324 | |||
/**
 * Predicts one of the two binary labels for a sample by applying the
 * trained split (column, operator and value) to it.
 *
 * @param array $sample
 *
 * @return mixed The first binary label when the split condition holds,
 *               the second one otherwise
 */
protected function predictSampleBinary(array $sample)
{
    $conditionHolds = $this->evaluate($sample[$this->column], $this->operator, $this->value);

    return $conditionHolds ? $this->binaryLabels[0] : $this->binaryLabels[1];
}
||
| 338 | |||
/**
 * No-op: the body performs no work in this class.
 * NOTE(review): presumably required by the OneVsRest trait - confirm.
 *
 * @return void
 */
protected function resetBinary()
{
    // Nothing to do here.
}
||
| 345 | |||
| 346 | /** |
||
| 347 | * @return string |
||
| 348 | */ |
||
| 349 | public function __toString() |
||
| 355 | } |
||
| 356 |
This check marks implicit conversions of arrays to boolean values in a comparison. While in PHP an empty array is considered to be equal (but not identical) to false, this is not always apparent.
Consider making the comparison explicit by using `empty(...)` or `!empty(...)` instead.