|
1
|
|
|
<?php |
|
2
|
|
|
|
|
3
|
|
|
declare(strict_types=1); |
|
4
|
|
|
|
|
5
|
|
|
namespace Phpml\Classification; |
|
6
|
|
|
|
|
7
|
|
|
use Phpml\Helper\Predictable; |
|
8
|
|
|
use Phpml\Helper\Trainable; |
|
9
|
|
|
use Phpml\Math\Statistic\Mean; |
|
10
|
|
|
use Phpml\Math\Statistic\StandardDeviation; |
|
11
|
|
|
|
|
12
|
|
|
/**
 * Naive Bayes classifier supporting both continuous and nominal features.
 *
 * For every label seen during training the classifier stores the prior
 * P(label) and, per feature, either a Gaussian (mean / standard deviation)
 * when all values of that feature are numeric, or a frequency table when at
 * least one value is non-numeric. The feature type is decided per label.
 */
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    // NOTE: "CONTINUOS" is misspelt, but it is a public constant, so the
    // name is kept as-is for backward compatibility.
    const CONTINUOS = 1;
    const NOMINAL = 2;

    /** Floor used for zero / vanishing probabilities and zero variance. */
    const SMALL_VALUE = 1e-32;

    /** @var array Standard deviation per label and feature index. */
    private $std = [];

    /** @var array Arithmetic mean per label and feature index. */
    private $mean = [];

    /** @var array Relative frequency of each value, per label and nominal feature index. */
    private $discreteProb = [];

    /** @var array Feature type (self::CONTINUOS or self::NOMINAL) per label and feature index. */
    private $dataType = [];

    /** @var array Prior probability P(label), keyed by label. */
    private $p = [];

    /** @var int Number of training samples. */
    private $sampleCount = 0;

    /** @var int Number of features, taken from the first training sample. */
    private $featureCount = 0;

    /** @var array Distinct labels found in the training targets. */
    private $labels = [];

    /**
     * Trains the classifier: computes the prior of every distinct label and
     * the per-feature statistics of the samples belonging to it.
     *
     * @param array $samples list of samples, each a list of feature values
     * @param array $targets label of each sample, in the same order
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->sampleCount = count($samples);
        $this->featureCount = count($samples[0]);

        // Drop statistics left over from a previous training run so that
        // labels absent from the new data do not linger.
        $this->p = [];
        $this->std = [];
        $this->mean = [];
        $this->dataType = [];
        $this->discreteProb = [];

        // array_unique() returns the de-duplicated array instead of changing
        // its argument, so the result has to be assigned.
        $this->labels = array_values(array_unique($targets));

        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Calculates the statistics of every feature for one label and caches
     * them in the private arrays so prediction does not recompute them.
     *
     * @param string $label
     * @param array  $samples samples belonging to $label
     */
    private function calculateStatistics($label, $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        // Only read for NOMINAL features; starts as an empty table per feature.
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; ++$i) {
            // Values of the i-th column across the samples of this label
            $values = array_column($samples, $i);
            $numValues = count($values);

            // A column containing any non-numeric value is treated as a
            // nominal/categorical/discrete feature.
            if ($values !== array_filter($values, 'is_numeric')) {
                $this->dataType[$label][$i] = self::NOMINAL;
                // Turn occurrence counts into relative frequencies;
                // array_map() keeps the value => count keys intact.
                $this->discreteProb[$label][$i] = array_map(
                    function ($count) use ($numValues) {
                        return $count / $numValues;
                    },
                    array_count_values($values)
                );
            } else {
                // Original author's note: the mean ends up being calculated
                // twice here, which could be optimized.
                $this->mean[$label][$i] = Mean::arithmetic($values);
                $this->std[$label][$i] = StandardDeviation::population($values, false);
            }
        }
    }

    /**
     * Calculates the likelihood P(feature value | label) of one feature of
     * a sample: the stored relative frequency for nominal features, the
     * Gaussian probability density for continuous ones.
     *
     * @param array  $sample
     * @param int    $feature
     * @param string $label
     *
     * @return float|int
     */
    private function sampleProbability($sample, $feature, $label)
    {
        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][$value] ?? 0;

            // Values never seen for this label get a tiny non-zero probability
            return $probability > 0 ? $probability : self::SMALL_VALUE;
        }

        // The feature was numeric during training; a non-numeric value cannot
        // be placed on the Gaussian, so treat it as (almost) impossible.
        if (!is_numeric($value)) {
            return self::SMALL_VALUE;
        }

        $std = $this->std[$label][$feature];
        $mean = $this->mean[$label][$feature];
        $variance = $std * $std;
        if ($variance == 0) {
            // Avoid division by zero for constant columns
            $variance = self::SMALL_VALUE;
        }

        // Probability density of the normal/Gaussian distribution
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        return (1 / sqrt(2 * $variance * pi())) * exp(-pow($value - $mean, 2) / (2 * $variance));
    }

    /**
     * Returns the training samples belonging to a specific label.
     *
     * @param string $label
     *
     * @return array
     */
    private function getSamplesByLabel($label)
    {
        $samples = [];
        for ($i = 0; $i < $this->sampleCount; ++$i) {
            // Compare as strings, matching how array_unique() decided which
            // labels are distinct (loose == would e.g. equate '1e3' and '1000').
            if ((string) $this->targets[$i] === (string) $label) {
                $samples[] = $this->samples[$i];
            }
        }

        return $samples;
    }

    /**
     * Predicts the label of a single sample, or of each sample when a list
     * of samples is given.
     *
     * @param array $sample one sample, or a list of samples
     *
     * @return mixed the predicted label, or a list of predicted labels
     */
    protected function predictSample(array $sample)
    {
        $isArray = isset($sample[0]) && is_array($sample[0]);
        $samples = $isArray ? $sample : [$sample];

        $samplePredictions = [];
        foreach ($samples as $row) {
            // Naive Bayes assumption, for each label:
            // P(label|features) ~ P(label) * P(feature0|label) * ... * P(featureN|label)
            // The labels are ranked by the logarithm of that product: the
            // order is identical, but summing logs cannot underflow to zero
            // the way multiplying many tiny probabilities does.
            $scores = [];
            foreach ($this->labels as $label) {
                $score = log($this->p[$label]);
                for ($i = 0; $i < $this->featureCount; ++$i) {
                    $likelihood = $this->sampleProbability($row, $i, $label);
                    // Clamp NaN, zero and vanishing values so log() stays finite
                    if (is_nan($likelihood) || $likelihood < self::SMALL_VALUE) {
                        $likelihood = self::SMALL_VALUE;
                    }
                    $score += log($likelihood);
                }
                $scores[$label] = $score;
            }

            // Highest score first; its key is the predicted label
            arsort($scores, SORT_NUMERIC);
            reset($scores);
            $samplePredictions[] = key($scores);
        }

        return $isArray ? $samplePredictions : $samplePredictions[0];
    }
}
|
158
|
|
|
|
Sometimes obsolete code just ends up commented out instead of removed. In this case it is better to remove the code once you have checked you do not need it.
The code might also have been commented out for debugging purposes. In this case it is vital that someone uncomments it again or your project may behave in very unexpected ways in production.
This check looks for comments that seem to be mostly valid code and reports them.