### DecisionStump   A last analyzed 2019-06-22 20:56 UTC

#### Complexity

 Total Complexity 36

#### Size/Duplication

 Total Lines 304 Duplicated Lines 0 %

#### Importance

 Changes 0
Metric Value
wmc 36
eloc 118
dl 0
loc 304
rs 9.52
c 0
b 0
f 0

#### 10 Methods

Rating   Name   Duplication   Size   Complexity
A resetBinary() 0 2 1
A __construct() 0 3 1
B calculateErrorRate() 0 38 8
A setNumericalSplitCount() 0 3 1
B getBestNumericalSplit() 0 45 6
B trainBinary() 0 53 9
A getBestNominalSplit() 0 24 5
A predictProbability() 0 8 2
A __toString() 0 5 1
A predictSampleBinary() 0 7 2
 *
     * If columnIndex is given, then the stump tries to produce a decision node
     * on this column, otherwise in cases given the value of -1, the stump itself
     * decides which column to take for the decision (Default DecisionTree behaviour)
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Renders the trained rule as
     * "IF <column> <operator> <value> THEN <first label> ELSE <second label>".
     */
    public function __toString(): string
    {
        // Braced property interpolation is required here: "${this}->column" would
        // interpolate $this itself (re-entering __toString()) and then append the
        // literal text "->column" instead of the property value.
        return "IF {$this->column} {$this->operator} {$this->value} ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }

    /**
     * While finding best split point for a numerical valued column,
     * DecisionStump looks for equally distanced values between minimum and maximum
     * values in the column. Given $count value determines how many split
     * points to be probed. The more split counts, the better performance but
     * worse processing time (Default value is 10.0)
     */
    public function setNumericalSplitCount(float $count): void
    {
        $this->numSplitCount = $count;
    }

    /**
     * Trains the stump: evaluates candidate splits on the selected column(s) and
     * stores the split with the lowest training error rate on this object.
     *
     * @param array $samples Rows of feature values; the first row defines the feature count
     * @param array $targets Target label of each sample
     * @param array $labels  The two labels; index 0 is predicted when the split condition holds
     *
     * @throws InvalidArgumentException when the number of weights differs from the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels): void
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // If a column index is given, it should be among the existing columns.
        // Out-of-range values on either side fall back to automatic selection
        // (assumes AUTO_SELECT is the negative sentinel described on the constructor).
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex > $this->featureCount - 1) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none given, then assign 1 as a weight to each sample
        if (count($this->weights) === 0) {
            $this->weights = array_fill(0, count($samples), 1);
        } else {
            $numWeights = count($this->weights);
            if ($numWeights !== count($samples)) {
                throw new InvalidArgumentException('Number of sample weights does not match with number of samples');
            }
        }

        // Determine type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the columns of the dataset
        // by calculating error rate for each split point in each column
        $columns = range(0, $this->featureCount - 1);
        if ($this->givenColumnIndex !== self::AUTO_SELECT) {
            $columns = [$this->givenColumnIndex];
        }

        // Baseline that any split with an error rate below 1.0 replaces
        $bestSplit = [
            'value' => 0,
            'operator' => '',
            'prob' => [],
            'column' => 0,
            'trainingErrorRate' => 1.0,
        ];
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
                $split = $this->getBestNumericalSplit($samples, $targets, $col);
            } else {
                $split = $this->getBestNominalSplit($samples, $targets, $col);
            }

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign determined best values to the stump
        // (each key of the split array names a property of this object)
        foreach ($bestSplit as $name => $value) {
            $this->{$name} = $value;
        }
    }

    /**
     * Determines best split point for the given column
     *
     * Candidate thresholds are the column mean followed by equally spaced points
     * from the column minimum up to its maximum; each is tried with '<=' and '>'.
     *
     * @return array Split description with keys value, operator, prob, column, trainingErrorRate
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);

        // Before trying all possible split points, the average value is tried
        // as the cut point, so it always heads the candidate list
        $thresholds = [array_sum($values) / (float) count($values)];

        // A non-positive split count or a constant column yields no usable step;
        // stepping with it would divide by zero or never reach $maxValue.
        $stepSize = $this->numSplitCount > 0
            ? ($maxValue - $minValue) / $this->numSplitCount
            : 0.0;
        if ($stepSize > 0) {
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $thresholds[] = (float) $step;
            }
        }

        $split = [];

        foreach (['<=', '>'] as $operator) {
            foreach ($thresholds as $threshold) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);
                if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $threshold,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Determines best split for a nominal column by testing every distinct
     * value of the column with the '=' and '!=' operators.
     *
     * @return array Split description with keys value, operator, prob, column, trainingErrorRate
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        $valueCounts = array_count_values($values);
        $distinctVals = array_keys($valueCounts);

        $split = [];

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);
                // Keep the candidate with the LOWEST error rate
                if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $val,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Calculates the ratio of wrong predictions based on the new threshold
     * value given as the parameter
     *
     * NOTE(review): $threshold is typed float, yet getBestNominalSplit() passes
     * distinct column values here; confirm nominal values are always numeric.
     *
     * @return array Two elements: the weighted error rate, and for each label the
     *               share of samples in that label's leaf whose target is the label
     */
    protected function calculateErrorRate(array $targets, float $threshold, string $operator, array $values): array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            if (Comparison::compare($value, $threshold, $operator)) {
                $predicted = $leftLabel;
            } else {
                $predicted = $rightLabel;
            }

            $target = $targets[$index];
            // Misclassified samples contribute their weight to the error
            if ((string) $predicted != (string) $target) {
                $wrong += $this->weights[$index];
            }

            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }

            // Unweighted count of targets seen in each predicted leaf
            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: Proportion of labels in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($prob[$leaf]);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample of belonging to the given label
     *
     * Probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node
     *
     * @param mixed $label
     *
     * @return float Stored proportion for $label when the stump predicts $label, otherwise 0.0
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            return $this->prob[$label];
        }

        return 0.0;
    }

    /**
     * Predicts the label of a sample: the first binary label when the trained
     * condition holds for the sample's value in the split column, else the second.
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * Intentionally empty: this method performs no reset work.
     */
    protected function resetBinary(): void
    {
    }
}