These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more
1 | <?php |
||
2 | |||
3 | declare(strict_types=1); |
||
4 | |||
5 | namespace Phpml\Classification; |
||
6 | |||
7 | use Phpml\Helper\Predictable; |
||
8 | use Phpml\Helper\Trainable; |
||
9 | use Phpml\Math\Statistic\Mean; |
||
10 | use Phpml\Math\Statistic\StandardDeviation; |
||
11 | |||
/**
 * Naive Bayes classifier.
 *
 * Features whose training values (per label) are all numeric are modelled with a
 * Gaussian distribution (mean + population standard deviation); any feature column
 * containing a non-numeric value is treated as nominal/categorical and modelled by
 * the relative frequency of each category.
 */
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    /**
     * Data type marker for a Gaussian-modelled (numeric) feature.
     * Note: the name is misspelled but is kept for backward compatibility.
     */
    public const CONTINUOS = 1;

    /**
     * Data type marker for a categorical (non-numeric) feature.
     */
    public const NOMINAL = 2;

    /**
     * Small constant used to avoid a zero standard deviation and a zero
     * probability for categories never seen during training.
     */
    public const EPSILON = 1e-10;

    /**
     * Per-label, per-feature population standard deviation (plus EPSILON)
     * of continuous features.
     *
     * @var array
     */
    private $std = [];

    /**
     * Per-label, per-feature arithmetic mean of continuous features.
     *
     * @var array
     */
    private $mean = [];

    /**
     * Per-label, per-feature map of category value => relative frequency,
     * filled for nominal features.
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * Per-label, per-feature data type (self::CONTINUOS or self::NOMINAL).
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Prior probability P(label) for each label.
     *
     * @var array
     */
    private $p = [];

    /**
     * @var int
     */
    private $sampleCount = 0;

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * Distinct training labels, stringified.
     *
     * @var string[]
     */
    private $labels = [];

    /**
     * Appends the given samples/targets to the training set and recomputes the
     * prior and per-feature statistics of every label seen so far.
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);
        $this->sampleCount = count($this->samples);
        // An empty training set has no first sample to take the feature count from.
        $this->featureCount = $this->sampleCount > 0 ? count($this->samples[0]) : 0;

        // array_count_values() turns numeric-string labels into int keys; cast them
        // back to strings, otherwise the string-typed helpers below would throw a
        // TypeError under strict_types.
        $this->labels = array_map('strval', array_keys(array_count_values($this->targets)));
        foreach ($this->labels as $label) {
            $samples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($samples) / $this->sampleCount;
            $this->calculateStatistics($label, $samples);
        }
    }

    /**
     * Returns the label with the highest posterior score for the given sample.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        // Naive Bayes treats features as conditionally independent given the label,
        // so a label's posterior is proportional to its prior multiplied by the
        // likelihood of each feature value under that label. The product is computed
        // as a sum of logarithms to avoid floating-point underflow, which is why
        // sampleProbability() returns log-likelihoods and the prior is logged here.
        $predictions = [];
        foreach ($this->labels as $label) {
            $logPosterior = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; ++$i) {
                $logPosterior += $this->sampleProbability($sample, $i, $label);
            }

            $predictions[$label] = $logPosterior;
        }

        arsort($predictions, SORT_NUMERIC);
        reset($predictions);

        return key($predictions);
    }

    /**
     * Calculates the statistics of each feature for one label and caches them in
     * the private arrays so they are not recomputed on every prediction.
     *
     * @param array $samples the training samples belonging to $label
     */
    private function calculateStatistics(string $label, array $samples): void
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        // Only read for nominal features; continuous features keep an empty map.
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; ++$i) {
            // Values of the i-th column among this label's samples.
            $values = array_column($samples, $i);
            $numValues = count($values);

            // array_filter() preserves keys, so the arrays are identical only when
            // every value is numeric; otherwise the column is treated as nominal.
            if ($values !== array_filter($values, 'is_numeric')) {
                $this->dataType[$label][$i] = self::NOMINAL;
                $this->discreteProb[$label][$i] = array_map(function ($count) use ($numValues) {
                    return $count / $numValues;
                }, array_count_values($values));
            } else {
                $this->mean[$label][$i] = Mean::arithmetic($values);
                // Add EPSILON so a constant column does not yield a zero deviation.
                $this->std[$label][$i] = self::EPSILON + StandardDeviation::population($values, false);
            }
        }
    }

    /**
     * Returns the log-likelihood of one feature of the sample given the label:
     * the log of the category's relative frequency for nominal features, or the
     * log of the Gaussian probability density for continuous features.
     */
    private function sampleProbability(array $sample, int $feature, string $label): float
    {
        $value = $sample[$feature];
        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            // Categories never seen for this label get EPSILON instead of zero so
            // that log() stays finite.
            $probability = $this->discreteProb[$label][$feature][$value] ?? 0;

            return log($probability > 0 ? $probability : self::EPSILON);
        }

        $std = $this->std[$label][$feature];
        $mean = $this->mean[$label][$feature];
        $variance = $std * $std;

        // Log of the normal/Gaussian probability density function, computed in
        // log space to avoid numerical errors with small or zero values
        // (the same approach scikit-learn uses).
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        $logPdf = -0.5 * log(2.0 * pi() * $variance);
        $logPdf -= 0.5 * pow($value - $mean, 2) / $variance;

        return $logPdf;
    }

    /**
     * Returns the training samples whose target equals the given label.
     */
    private function getSamplesByLabel(string $label): array
    {
        $samples = [];
        for ($i = 0; $i < $this->sampleCount; ++$i) {
            // Compare as strings: labels are stringified in train(), and a loose
            // comparison could match unrelated targets (e.g. 0 == 'a' on PHP 7,
            // '01' == '1'), assigning a sample to several labels.
            if ((string) $this->targets[$i] === $label) {
                $samples[] = $this->samples[$i];
            }
        }

        return $samples;
    }
}
||
180 |
Sometimes obsolete code just ends up commented out instead of removed. In this case it is better to remove the code once you have checked you do not need it.
The code might also have been commented out for debugging purposes. In this case it is vital that someone uncomments it again or your project may behave in very unexpected ways in production.
This check looks for comments that seem to be mostly valid code and reports them.