Completed
Pull Request — master (#36)
by
unknown
03:44 queued 01:01
created

DecisionTree::preprocess()   B

Complexity

Conditions 5
Paths 3

Size

Total Lines 23
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 23
rs 8.5906
c 0
b 0
f 0
cc 5
eloc 13
nc 3
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
11
12
/**
 * Classification tree built top-down: at every node the feature column whose
 * split has the lowest Gini index is chosen, and continuous (non-categorical)
 * columns are discretised around their median before being compared.
 */
class DecisionTree implements Classifier
{
    use Trainable, Predictable;

    const CONTINUOS = 1;
    const NOMINAL = 2;

    /**
     * Training samples accumulated over all calls to train().
     *
     * @var array
     */
    private $samples = [];

    /**
     * Per-feature column type (self::CONTINUOS or self::NOMINAL).
     *
     * @var array
     */
    private $columnTypes;

    /**
     * Distinct class labels seen in the training targets.
     *
     * @var array
     */
    private $labels = [];

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * Root of the trained tree; null until train() has been called.
     *
     * @var DecisionTreeLeaf|null
     */
    private $tree = null;

    /**
     * @var int
     */
    private $maxDepth;

    /**
     * Deepest node level reached while building the current tree.
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * @param int $maxDepth depth at which a node is forced to become terminal
     */
    public function __construct($maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Adds the given samples/targets to the training set and rebuilds the tree
     * from all accumulated data.
     *
     * NOTE(review): $this->targets is not declared in this class; presumably it
     * comes from the Trainable trait — confirm.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \InvalidArgumentException when no training samples are available
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets ?? [], $targets);

        if (count($this->samples) === 0) {
            throw new \InvalidArgumentException('Cannot train a decision tree without samples');
        }

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = $this->getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));
        // The tree is rebuilt from scratch, so its depth must be measured afresh.
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
    }

    /**
     * Classifies every feature column as nominal (categorical) or continuous.
     *
     * @param array $samples
     * @return int[] self::NOMINAL or self::CONTINUOS for each feature index
     */
    protected function getColumnTypes(array $samples)
    {
        $types = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            $types[] = $this->isCategoricalColumn($values) ? self::NOMINAL : self::CONTINUOS;
        }

        return $types;
    }

    /**
     * Recursively builds the (sub)tree for the given training records.
     *
     * A node becomes terminal when all its records share one target, when all
     * its records are identical, or when $depth reaches the configured maximum.
     * A terminal node predicts the most frequent target among its records.
     *
     * @param int[] $records indices into $this->samples / $this->targets
     * @param int   $depth   level of the node being built (root = 0)
     * @return DecisionTreeLeaf
     */
    protected function getSplitLeaf($records, $depth = 0)
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        $leftRecords = [];
        $rightRecords = [];
        // Number of occurrences of each target label among $records.
        $targetCounts = [];
        $prevRecord = null;
        $allSame = true;
        foreach ($records as $recordNo) {
            $record = $this->samples[$recordNo];
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }
            $prevRecord = $record;

            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            $target = $this->targets[$recordNo];
            $targetCounts[$target] = ($targetCounts[$target] ?? 0) + 1;
        }

        if (count($targetCounts) === 1 || $allSame || $depth >= $this->maxDepth) {
            $split->isTerminal = 1;
            // Majority vote over the records reaching this node (ties: first seen).
            arsort($targetCounts);
            $split->classValue = key($targetCounts);
        } else {
            if ($leftRecords) {
                $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
            }
            if ($rightRecords) {
                $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
            }
        }

        return $split;
    }

    /**
     * Picks the feature column whose split yields the lowest Gini index.
     *
     * The split value for a column is the first non-null (preprocessed) value
     * in that column among $records.
     *
     * @param int[] $records indices into $this->samples / $this->targets
     * @return DecisionTreeLeaf
     */
    protected function getBestSplit($records)
    {
        $targets = array_intersect_key($this->targets, array_flip($records));

        // Collect rows in exactly the order of $records so that array_combine
        // pairs every record index with its own preprocessed row.
        $rows = [];
        foreach ($records as $recordNo) {
            $rows[] = $this->samples[$recordNo];
        }
        $samples = array_combine($records, $this->preprocess($rows));

        $bestGiniVal = 1;
        $bestSplit = null;
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $colValues = [];
            $baseValue = null;
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
                if ($baseValue === null) {
                    $baseValue = $row[$i];
                }
            }

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->records = $records;
                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Weighted Gini impurity of partitioning a column into values equal to
     * $baseValue and all other values.
     *
     * @param mixed $baseValue value defining the "equal" partition
     * @param array $colValues column values keyed by record index
     * @param array $targets   target labels keyed by record index
     * @return float|int 0 for a perfectly pure split (or empty input)
     */
    public function getGiniIndex($baseValue, $colValues, $targets)
    {
        $total = count($colValues);
        if ($total === 0) {
            return 0;
        }

        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }
        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            // Loose comparison kept on purpose (e.g. '5' matches 5).
            $rowIndex = $value == $baseValue ? 0 : 1;
            $countMatrix[$label][$rowIndex]++;
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; ++$i) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += ($countMatrix[$label][$i] / floatval($sum)) ** 2;
                }
            }
            // Impurity of this partition weighted by its size.
            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / $total;
    }

    /**
     * Replaces continuous column values by "<= median" / "> median" labels so
     * every column can be split on a single discrete value.
     *
     * @param array $samples list of rows
     * @return array list of preprocessed rows, in the same order as $samples
     */
    protected function preprocess(array $samples)
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOS) {
                $median = Mean::median($values);
                foreach ($values as &$value) {
                    $value = $value <= $median ? "<= $median" : "> $median";
                }
                // Break the reference foreach leaves bound to the last element.
                unset($value);
            }
            $columns[] = $values;
        }

        // Transpose the columns back into rows. array_map(null, ...$columns)
        // is not used because with a single column it returns that column
        // unchanged instead of wrapping each value in a one-element row.
        $rows = [];
        foreach ($columns as $columnIndex => $column) {
            foreach ($column as $rowIndex => $value) {
                $rows[$rowIndex][$columnIndex] = $value;
            }
        }

        return $rows;
    }

    /**
     * Heuristically decides whether a column holds a discrete set of values.
     *
     * A column is treated as categorical when it contains any non-numeric
     * value, or when its number of distinct values is at most 20% of its size.
     *
     * @param array $columnValues
     * @return bool
     */
    protected function isCategoricalColumn(array $columnValues)
    {
        $count = count($columnValues);

        $numericValues = array_filter($columnValues, 'is_numeric');
        if (count($numericValues) !== $count) {
            return true;
        }

        // array_unique instead of array_count_values: the latter only counts
        // int/string entries and skips floats, which made every float column
        // look like it had zero distinct values (i.e. "categorical").
        $distinctValues = array_unique($columnValues);

        return count($distinctValues) <= $count / 5;
    }

    /**
     * @return string HTML rendering of the trained tree ('' before training)
     */
    public function getHtml()
    {
        if ($this->tree === null) {
            return '';
        }

        return $this->tree->__toString();
    }

    /**
     * Walks the tree from the root until a terminal node is reached.
     *
     * If the walk leaves the tree through a branch that received no training
     * records, the most frequent target among the records of the last visited
     * node is returned instead.
     *
     * @param array $sample
     * @return mixed predicted label
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        $lastNode = null;
        while ($node !== null && !$node->isTerminal) {
            $lastNode = $node;
            $node = $node->evaluate($sample) ? $node->leftLeaf : $node->rightLeaf;
        }

        if ($node !== null) {
            return $node->classValue;
        }

        if ($lastNode !== null) {
            return $this->getMajorityClass($lastNode->records);
        }

        // No tree at all (not trained yet).
        return $this->labels[0] ?? null;
    }

    /**
     * Most frequent target among the given training records.
     *
     * @param int[] $records indices into $this->targets
     * @return mixed|null the label, or null when $records is empty
     */
    private function getMajorityClass(array $records)
    {
        $counts = [];
        foreach ($records as $recordNo) {
            $target = $this->targets[$recordNo];
            $counts[$target] = ($counts[$target] ?? 0) + 1;
        }
        arsort($counts);

        return key($counts);
    }
}
313