Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Classification/NaiveBayes.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
12
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    // Note: the historical misspelling "CONTINUOS" is part of the public API and is kept as-is.
    public const CONTINUOS = 1;

    public const NOMINAL = 2;

    /**
     * Floor used for probabilities (unseen nominal values) and added to standard
     * deviations, so that log(0) and division by zero cannot occur.
     */
    public const EPSILON = 1e-10;

    /**
     * Per label, per feature index: standard deviation of a continuous feature (plus EPSILON).
     *
     * @var array
     */
    private $std = [];

    /**
     * Per label, per feature index: arithmetic mean of a continuous feature.
     *
     * @var array
     */
    private $mean = [];

    /**
     * Per label, per feature index: map of (stringified) value => relative frequency,
     * only populated for nominal features.
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * Per label, per feature index: self::CONTINUOS or self::NOMINAL.
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Label => prior probability P(label) (fraction of training samples with that label).
     *
     * @var array
     */
    private $p = [];

    /**
     * @var int
     */
    private $sampleCount = 0;

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * Distinct labels, as strings, in order of first appearance in the targets.
     *
     * @var string[]
     */
    private $labels = [];

    /**
     * Adds the given samples/targets to the training set and recomputes the
     * per-label statistics over the whole accumulated set.
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);
        $this->sampleCount = count($this->samples);
        $this->featureCount = count($this->samples[0]);

        $samplesByLabel = $this->groupSamplesByLabel();

        // PHP turns numeric-string array keys back into ints; convert them to strings
        // so they satisfy the `string $label` parameters under strict_types.
        $this->labels = array_map('strval', array_keys($samplesByLabel));
        foreach ($this->labels as $label) {
            $labelSamples = $samplesByLabel[$label];
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Returns the label with the highest naive Bayes score for the given sample.
     *
     * @return mixed the winning label (int for integer-like labels, string otherwise)
     */
    protected function predictSample(array $sample)
    {
        // Naive Bayes: the posterior of a label is proportional to its prior times the
        // product of the per-feature likelihoods. That product is evaluated in log
        // space (log prior plus the sum of per-feature log likelihoods) so that many
        // small factors do not underflow; the label with the largest score wins.
        $predictions = [];
        foreach ($this->labels as $label) {
            $score = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; ++$i) {
                $score += $this->sampleLogProbability($sample, $i, $label);
            }

            $predictions[$label] = $score;
        }

        arsort($predictions, SORT_NUMERIC);
        reset($predictions);

        return key($predictions);
    }

    /**
     * Calculates the statistics for each feature of one label and caches them in
     * the private arrays, so that prediction does not recompute them.
     */
    private function calculateStatistics(string $label, array $samples): void
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);
        for ($i = 0; $i < $this->featureCount; ++$i) {
            // Values of the i-th feature (column) across this label's samples.
            $values = array_column($samples, $i);
            $numValues = count($values);

            // Any non-numeric value makes the whole column nominal/categorical.
            if ($values !== array_filter($values, 'is_numeric')) {
                $this->dataType[$label][$i] = self::NOMINAL;
                // Count by string form: array_count_values only accepts int/string values.
                $this->discreteProb[$label][$i] = array_map(
                    function ($count) use ($numValues) {
                        return $count / $numValues;
                    },
                    array_count_values(array_map('strval', $values))
                );

                continue;
            }

            $this->mean[$label][$i] = Mean::arithmetic($values);
            // Add EPSILON so a constant feature does not yield a zero standard deviation.
            $this->std[$label][$i] = self::EPSILON + StandardDeviation::population($values, false);
        }
    }

    /**
     * Returns log P(sample[feature] | label).
     *
     * Nominal features use the observed relative frequency, floored at EPSILON for
     * unseen values. Continuous features use the log of the normal (Gaussian)
     * probability density, as scikit-learn's GaussianNB does, to avoid numerical
     * problems with very small densities.
     */
    private function sampleLogProbability(array $sample, int $feature, string $label): float
    {
        $value = $sample[$feature];
        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][(string) $value] ?? 0.0;

            return log(max($probability, self::EPSILON));
        }

        $mean = $this->mean[$label][$feature];
        $variance = $this->std[$label][$feature] ** 2;

        // log N(x; mean, std) = -0.5 * log(2 * pi * std^2) - (x - mean)^2 / (2 * std^2)
        return -0.5 * log(2.0 * M_PI * $variance) - 0.5 * (($value - $mean) ** 2) / $variance;
    }

    /**
     * Groups the accumulated training samples by the string form of their target,
     * preserving the order in which labels first appear.
     *
     * @return array<string|int, array[]>
     */
    private function groupSamplesByLabel(): array
    {
        $groups = [];
        for ($i = 0; $i < $this->sampleCount; ++$i) {
            $groups[(string) $this->targets[$i]][] = $this->samples[$i];
        }

        return $groups;
    }
}
180