1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Classification; |
6
|
|
|
|
7
|
|
|
use Phpml\Helper\Predictable; |
8
|
|
|
use Phpml\Helper\Trainable; |
9
|
|
|
use Phpml\Math\Statistic\Mean; |
10
|
|
|
use Phpml\Classification\DecisionTree\DecisionTreeLeaf; |
11
|
|
|
|
12
|
|
|
class DecisionTree implements Classifier
{
    use Trainable;
    use Predictable;

    const CONTINUOS = 1;
    const NOMINAL = 2;

    /**
     * @var array
     */
    private $samples = [];

    /**
     * Type of each feature column (self::CONTINUOS or self::NOMINAL), indexed by column.
     *
     * @var array
     */
    private $columnTypes;

    /**
     * Distinct target labels seen during training.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Number of features per sample, taken from the first training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Root node of the trained tree; null until train() has been called.
     *
     * @var DecisionTreeLeaf
     */
    private $tree = null;

    /**
     * @var int
     */
    private $maxDepth;

    /**
     * Depth of the deepest node created by the most recent call to train().
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * Number of randomly chosen features considered per split; 0 means all features.
     *
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * @param int $maxDepth nodes at this depth are always made terminal
     */
    public function __construct($maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Append the given samples/targets to the training set and rebuild the
     * tree from scratch over all data seen so far.
     *
     * @param array $samples list of feature rows
     * @param array $targets target label for each row, in the same order as $samples
     *
     * @throws \InvalidArgumentException if $samples and $targets differ in length,
     *                                   or if there is no training data at all
     */
    public function train(array $samples, array $targets)
    {
        if (count($samples) !== count($targets)) {
            throw new \InvalidArgumentException('Number of samples and targets must be equal');
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // Without this guard an empty set fails later on $this->samples[0]
        // and on range(0, -1), which yields [0, -1] rather than an empty list.
        if (empty($this->samples)) {
            throw new \InvalidArgumentException('Cannot train a decision tree without samples');
        }

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = $this->getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        // The tree is rebuilt from scratch, so its depth statistic must be too.
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
    }

    /**
     * Classify every feature column as nominal or continuous.
     *
     * @param array $samples list of feature rows
     *
     * @return array column index => self::NOMINAL|self::CONTINUOS
     */
    protected function getColumnTypes(array $samples)
    {
        $types = [];
        for ($i = 0; $i < $this->featureCount; $i++) {
            $values = array_column($samples, $i);
            $isCategorical = $this->isCategoricalColumn($values);
            $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
        }

        return $types;
    }

    /**
     * Recursively build the (sub)tree for the given training records.
     *
     * A node becomes terminal when all of its records share one target, when
     * all of its records have identical feature rows, or when $depth has
     * reached maxDepth. A terminal node is labelled with the most frequent
     * target among its records.
     *
     * @param array $records non-empty list of indices into $this->samples / $this->targets
     * @param int   $depth   depth of the node being built (root = 0)
     *
     * @return DecisionTreeLeaf
     */
    protected function getSplitLeaf(array $records, int $depth = 0)
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        $leftRecords = [];
        $rightRecords = [];
        // target label => number of records in this node carrying that label
        $targetCounts = [];
        $prevRecord = null;
        $allSame = true;
        foreach ($records as $recordNo) {
            $record = $this->samples[$recordNo];
            // Comparing each record with its predecessor is enough to tell
            // whether every record in the node is identical.
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }
            $prevRecord = $record;

            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            $target = $this->targets[$recordNo];
            $targetCounts[$target] = ($targetCounts[$target] ?? 0) + 1;
        }

        if (count($targetCounts) == 1 || $allSame || $depth >= $this->maxDepth) {
            $split->isTerminal = 1;
            // Real per-label counts (not just the distinct labels) are needed so
            // that nodes cut off by $allSame or maxDepth get the majority label.
            arsort($targetCounts);
            $split->classValue = key($targetCounts);
        } else {
            // A child is only created for a non-empty side of the split.
            if ($leftRecords) {
                $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
            }
            if ($rightRecords) {
                $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
            }
        }

        return $split;
    }

    /**
     * Find the best binary split for the given records.
     *
     * For each candidate feature a single split is evaluated: records whose
     * (preprocessed) value equals that feature's most frequent value versus
     * all other records. The candidate with the lowest weighted Gini index
     * wins; ties keep the earlier feature.
     *
     * @param array $records non-empty list of indices into $this->samples
     *
     * @return DecisionTreeLeaf
     */
    protected function getBestSplit(array $records)
    {
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = array_intersect_key($this->samples, $recordKeys);
        // array_intersect_key keeps the key order of $this->samples, which
        // matches $records because records are always collected in ascending order.
        $samples = array_combine($records, $this->preprocess($samples));
        $bestGiniVal = 1;
        $bestSplit = null;
        $features = $this->getSelectedFeatures();
        foreach ($features as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }
            // NOTE(review): array_count_values() only counts int/string values;
            // float values in a nominal column are skipped with a warning.
            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->records = $records;
                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Choose the feature columns to consider for one split.
     *
     * @return array sorted list of column indices: all columns when
     *               numUsableFeatures is 0, otherwise a random subset of
     *               min(numUsableFeatures, featureCount) columns
     */
    protected function getSelectedFeatures()
    {
        $allFeatures = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures == 0) {
            return $allFeatures;
        }

        $numFeatures = $this->numUsableFeatures;
        if ($numFeatures > $this->featureCount) {
            $numFeatures = $this->featureCount;
        }
        shuffle($allFeatures);
        $selectedFeatures = array_slice($allFeatures, 0, $numFeatures, false);
        sort($selectedFeatures);

        return $selectedFeatures;
    }

    /**
     * Weighted Gini impurity of splitting $colValues into
     * "equals $baseValue" and "everything else".
     *
     * @param string $baseValue value defining the first group (compared loosely with ==)
     * @param array  $colValues record index => column value
     * @param array  $targets   record index => target label
     *
     * @return float sum over both groups of (1 - sum_k p_k^2) * groupSize, divided by the record count
     */
    public function getGiniIndex($baseValue, $colValues, $targets)
    {
        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }
        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            $rowIndex = $value == $baseValue ? 0 : 1;
            $countMatrix[$label][$rowIndex]++;
        }
        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; $i++) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += pow($countMatrix[$label][$i] / floatval($sum), 2);
                }
            }
            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / count($colValues);
    }

    /**
     * Discretise continuous columns and return the rows.
     *
     * Each value of a continuous column is replaced by the string
     * "<= {median}" or "> {median}", using the median of that column over the
     * given samples; nominal columns are passed through unchanged.
     *
     * @param array $samples feature rows
     *
     * @return array list of rows (re-indexed from 0), one per input sample, in input order
     */
    protected function preprocess(array $samples)
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; $i++) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] == self::CONTINUOS) {
                $median = Mean::median($values);
                foreach ($values as &$value) {
                    if ($value <= $median) {
                        $value = "<= $median";
                    } else {
                        $value = "> $median";
                    }
                }
                // Drop the dangling reference left behind by foreach-by-reference.
                unset($value);
            }
            $columns[] = $values;
        }

        // Transpose columns back into rows. array_map(null, ...$columns) is not
        // used because with a single column it returns that column unchanged
        // instead of a list of one-element rows.
        $rows = [];
        foreach ($columns as $colIndex => $column) {
            foreach ($column as $rowIndex => $cell) {
                $rows[$rowIndex][$colIndex] = $cell;
            }
        }

        return $rows;
    }

    /**
     * Heuristically decide whether a column holds discrete (nominal) values.
     *
     * A column is nominal when it contains any non-numeric value, or when its
     * number of distinct values is at most 20% of its number of values.
     *
     * @param array $columnValues
     *
     * @return bool
     */
    protected function isCategoricalColumn(array $columnValues)
    {
        $count = count($columnValues);

        $numericValues = array_filter($columnValues, 'is_numeric');
        if (count($numericValues) != $count) {
            return true;
        }

        // array_unique() rather than array_count_values(): the latter only
        // accepts int/string values and skips floats (with a warning), which
        // made every float column appear to have no distinct values and thus
        // be treated as nominal.
        $distinctCount = count(array_unique($columnValues));

        return $distinctCount <= $count / 5;
    }

    /**
     * This method is used to set number of columns to be used
     * when deciding a split at an internal node of the tree. <br>
     * If the value is given 0, then all features are used (default behaviour),
     * otherwise the given value will be used as a maximum for number of columns
     * randomly selected for each split operation.
     *
     * @param int $numFeatures
     *
     * @return $this
     *
     * @throws \Exception if $numFeatures is negative
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures < 0) {
            throw new \Exception("Selected column count should be greater or equal to zero");
        }

        $this->numUsableFeatures = $numFeatures;

        return $this;
    }

    /**
     * String rendering of the trained tree, delegated to the root node's
     * __toString(). Requires train() to have been called first.
     *
     * @return string
     */
    public function getHtml()
    {
        return $this->tree->__toString();
    }

    /**
     * Walk the tree from the root to a terminal node and return its class.
     *
     * If the walk reaches a missing child (a split that sent all training
     * records to the other side creates no child), the first label seen in
     * training is returned as a fallback.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        do {
            if ($node->isTerminal) {
                break;
            }
            if ($node->evaluate($sample)) {
                $node = $node->leftLeaf;
            } else {
                $node = $node->rightLeaf;
            }
        } while ($node);

        return $node ? $node->classValue : $this->labels[0];
    }
}
327
|
|
|
|
This check looks at variables that are received as parameters and then passed on to other methods.
If the called method has stricter type requirements than the method that received the variable, an issue is raised.
Adding a type check (or declaring a stricter parameter type) may prevent trouble.