NaiveBayes   A
last analyzed

Complexity

Total Complexity 16

Size/Duplication

Total Lines 170
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 16
eloc 68
dl 0
loc 170
rs 10
c 0
b 0
f 0

5 Methods

Rating   Name   Duplication   Size   Complexity  
A getSamplesByLabel() 0 10 3
A train() 0 12 2
A sampleProbability() 0 29 5
A predictSample() 0 20 3
A calculateStatistics() 0 24 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Helper\Predictable;
9
use Phpml\Helper\Trainable;
10
use Phpml\Math\Statistic\Mean;
11
use Phpml\Math\Statistic\StandardDeviation;
12
13
/**
 * Naive Bayes classifier supporting both numeric and categorical features.
 *
 * Numeric feature columns are modelled per label with a Gaussian distribution
 * (mean and standard deviation); columns containing any non-numeric value are
 * modelled per label with the relative frequency of each observed value.
 * Scores are accumulated in log space so that priors, Gaussian densities and
 * categorical frequencies are all combined on the same scale.
 */
class NaiveBayes implements Classifier
{
    use Trainable;
    use Predictable;

    /**
     * Feature type: numeric column, scored with a Gaussian density.
     * The spelling is part of the public API and is therefore left unchanged.
     */
    public const CONTINUOS = 1;

    /**
     * Feature type: column with non-numeric values, scored with relative frequencies.
     */
    public const NOMINAL = 2;

    /**
     * Small positive constant that replaces zero probabilities and keeps the
     * standard deviation strictly positive.
     */
    public const EPSILON = 1e-10;

    /**
     * Standard deviation (plus EPSILON) indexed by label, then by feature index.
     *
     * @var array
     */
    private $std = [];

    /**
     * Arithmetic mean indexed by label, then by feature index.
     *
     * @var array
     */
    private $mean = [];

    /**
     * Relative frequency of each value indexed by label, feature index and value.
     * Only populated for NOMINAL features.
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * Feature type (CONTINUOS or NOMINAL) indexed by label, then by feature index.
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Prior probability of each label: the share of training samples with that label.
     *
     * @var array
     */
    private $p = [];

    /**
     * Number of training samples seen so far.
     *
     * @var int
     */
    private $sampleCount = 0;

    /**
     * Number of features, taken from the first training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Distinct target values, cast to string, in order of first appearance.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Adds the given samples to the training set and recomputes the priors and
     * per-label feature statistics over everything seen so far.
     *
     * @param array $samples list of feature rows
     * @param array $targets list of labels, one per sample
     *
     * @throws InvalidArgumentException when the number of samples and targets differ
     */
    public function train(array $samples, array $targets): void
    {
        // A length mismatch would make sample and target indices drift apart.
        if (count($samples) !== count($targets)) {
            throw new InvalidArgumentException('Number of samples must be equal to number of targets');
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);
        $this->sampleCount = count($this->samples);
        // An empty training set has no first row to measure; treat it as zero features.
        $this->featureCount = count($this->samples[0] ?? []);

        // Labels are normalised to strings because the helpers below are typed `string $label`.
        $this->labels = array_values(array_unique(array_map('strval', $this->targets)));
        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Predicts the most likely label for a single sample.
     *
     * Uses the Naive Bayes assumption in log space:
     *   log P(label|features) ~ log P(label) + sum_i log P(feature_i|label)
     * and returns the label with the highest score.
     *
     * @return mixed the winning label as an array key (numeric-string labels come
     *               back as integers); null when no label is known
     */
    protected function predictSample(array $sample)
    {
        $predictions = [];
        foreach ($this->labels as $label) {
            // Every known label has at least one training sample, so the prior is > 0.
            $logProbability = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; ++$i) {
                $logProbability += $this->sampleProbability($sample, $i, $label);
            }

            $predictions[$label] = $logProbability;
        }

        arsort($predictions, SORT_NUMERIC);
        reset($predictions);

        return key($predictions);
    }

    /**
     * Calculates the statistics of every feature for one label and caches them
     * so they are not recomputed for each prediction.
     *
     * @param string $label   label the samples belong to
     * @param array  $samples training rows carrying that label
     */
    private function calculateStatistics(string $label, array $samples): void
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; ++$i) {
            // Values of the i-th column among the samples of this label.
            $values = array_column($samples, $i);

            // A column holding any non-numeric value is treated as
            // nominal/categorical/discrete.
            if ($values !== array_filter($values, 'is_numeric')) {
                $numValues = count($values);
                $this->dataType[$label][$i] = self::NOMINAL;
                // Turn occurrence counts into relative frequencies (keys are preserved).
                $this->discreteProb[$label][$i] = array_map(
                    static function ($occurrences) use ($numValues) {
                        return $occurrences / $numValues;
                    },
                    array_count_values($values)
                );

                continue;
            }

            $this->mean[$label][$i] = Mean::arithmetic($values);
            // Add epsilon in order to avoid a zero standard deviation.
            $this->std[$label][$i] = self::EPSILON + StandardDeviation::population($values, false);
        }
    }

    /**
     * Calculates the log-likelihood log P(sample_n|label) of one feature value.
     *
     * @param array  $sample  sample being classified
     * @param int    $feature index of the feature to score
     * @param string $label   label whose statistics are used
     *
     * @throws InvalidArgumentException when the sample has no value for the feature
     */
    private function sampleProbability(array $sample, int $feature, string $label): float
    {
        if (!isset($sample[$feature])) {
            throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
        }

        $value = $sample[$feature];
        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][$value] ?? 0;

            // Values never seen for this label get EPSILON instead of a zero
            // probability; the logarithm keeps the result on the same scale as
            // the Gaussian branch below.
            return log($probability > 0 ? $probability : self::EPSILON);
        }

        $std = $this->std[$label][$feature];
        $mean = $this->mean[$label][$feature];
        // Logarithm of the normal/Gaussian probability density.
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        //
        // Working with logarithms avoids numerical errors caused by very small
        // or zero values, as scikit-learn does.
        // (See : https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py)
        $pdf = -0.5 * log(2.0 * M_PI * $std * $std);
        $pdf -= 0.5 * (($value - $mean) ** 2) / ($std * $std);

        return $pdf;
    }

    /**
     * Returns the training samples belonging to a specific label.
     *
     * @param string $label label in its string form
     */
    private function getSamplesByLabel(string $label): array
    {
        $samples = [];
        foreach ($this->targets as $index => $target) {
            // Compare string forms strictly: a loose comparison would treat distinct
            // numeric-looking labels such as "01" and "1" as the same label.
            if ((string) $target === $label) {
                $samples[] = $this->samples[$index];
            }
        }

        return $samples;
    }
}
185