Completed
Push — master ( 95fc13...87396e )
by Arkadiusz
02:45
created

NaiveBayes::predictSample()   A

Complexity

Conditions 3
Paths 3

Size

Total Lines 18
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 18
rs 9.4285
c 0
b 0
f 0
cc 3
eloc 11
nc 3
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
12
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    /**
     * Feature type: all-numeric column, modelled with a Gaussian (mean / standard deviation).
     */
    const CONTINUOUS = 1;

    /**
     * @deprecated Misspelled alias of CONTINUOUS, kept so existing references keep working.
     */
    const CONTINUOS = self::CONTINUOUS;

    /**
     * Feature type: column containing at least one non-numeric value, modelled with the
     * relative frequency of each distinct value.
     */
    const NOMINAL = 2;

    /**
     * Small positive floor added to standard deviations and used for unseen nominal values,
     * so that no density or probability is exactly zero before its logarithm is taken.
     */
    const EPSILON = 1e-10;

    /**
     * Per label, per feature: population standard deviation plus EPSILON (continuous features).
     *
     * @var array
     */
    private $std = [];

    /**
     * Per label, per feature: arithmetic mean (continuous features).
     *
     * @var array
     */
    private $mean = [];

    /**
     * Per label, per feature: map of stringified value => relative frequency (nominal features).
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * Per label, per feature: CONTINUOUS or NOMINAL.
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Prior P(label): fraction of the training samples that carry each label.
     *
     * @var array
     */
    private $p = [];

    /**
     * @var int
     */
    private $sampleCount = 0;

    /**
     * Number of features, taken from the first training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Distinct labels in order of first appearance in the training targets.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Fits the model: estimates the label priors and, for every label, the per-feature
     * statistics used at prediction time. State from any previous training run is discarded.
     *
     * @param array $samples feature vectors; all are expected to have the first sample's length
     * @param array $targets one label per sample, aligned with $samples by position
     *
     * @throws \InvalidArgumentException if $samples is empty or the arrays differ in length
     */
    public function train(array $samples, array $targets)
    {
        if ($samples === []) {
            throw new \InvalidArgumentException('Cannot train NaiveBayes without samples.');
        }
        if (count($samples) !== count($targets)) {
            throw new \InvalidArgumentException('The number of samples and targets must be equal.');
        }

        // Reindex so position i of both arrays always refers to the same (sample, target) pair.
        $this->samples = array_values($samples);
        $this->targets = array_values($targets);
        $this->sampleCount = count($this->samples);
        $this->featureCount = count($this->samples[0]);

        // Forget statistics learned by an earlier call.
        $this->std = [];
        $this->mean = [];
        $this->discreteProb = [];
        $this->dataType = [];
        $this->p = [];

        $this->labels = array_keys(array_count_values($this->targets));
        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Computes and caches the per-feature statistics of one label's samples, so they are not
     * recomputed for every prediction.
     *
     * A column whose values are all numeric is treated as CONTINUOUS: mean and population
     * standard deviation, the latter raised by EPSILON so it is never zero. A column with any
     * non-numeric value is treated as NOMINAL: relative frequency of each distinct value.
     *
     * @param int|string $label   label whose samples are summarised
     * @param array      $samples training samples carrying $label
     */
    private function calculateStatistics($label, array $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOUS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; $i++) {
            $values = array_column($samples, $i);

            // array_filter preserves keys, so the arrays are identical only if every value is numeric.
            if ($values === array_filter($values, 'is_numeric')) {
                $this->mean[$label][$i] = Mean::arithmetic($values);
                $this->std[$label][$i] = self::EPSILON + StandardDeviation::population($values, false);
                continue;
            }

            $this->dataType[$label][$i] = self::NOMINAL;
            $this->discreteProb[$label][$i] = $this->relativeFrequencies($values);
        }
    }

    /**
     * Maps each distinct value of a nominal column to the fraction of rows holding it.
     *
     * Values are stringified first because array_count_values() only counts ints and strings
     * (it warns about and skips floats, bools and null). sampleProbability() looks values up
     * with the same (string) cast, so training and prediction keys agree.
     *
     * @param array $values non-empty column of raw feature values
     *
     * @return array stringified value => relative frequency
     */
    private function relativeFrequencies(array $values): array
    {
        $total = count($values);

        return array_map(static function ($count) use ($total) {
            return $count / $total;
        }, array_count_values(array_map('strval', $values)));
    }

    /**
     * Log-likelihood log P(x_feature | label) of one feature value of a sample.
     *
     * NOMINAL features: log of the value's relative frequency among the label's training
     * samples. Unseen values are floored at EPSILON so the result stays finite.
     * CONTINUOUS features: log of the normal density with the label's mean and standard
     * deviation (https://en.wikipedia.org/wiki/Normal_distribution).
     *
     * Both cases return a logarithm, so predictSample() can simply add the terms. Working in
     * log space (as scikit-learn does, see
     * https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py)
     * avoids underflow from very small densities.
     *
     * @param array      $sample  the sample being classified
     * @param int        $feature column index
     * @param int|string $label   label whose statistics are used
     *
     * @return float log-likelihood of the feature value
     */
    private function sampleProbability(array $sample, int $feature, $label): float
    {
        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][(string) $value] ?? 0;

            return log($probability > 0 ? $probability : self::EPSILON);
        }

        $std = $this->std[$label][$feature];
        $variance = $std * $std;
        $deviation = $value - $this->mean[$label][$feature];

        return -0.5 * log(2.0 * M_PI * $variance) - 0.5 * ($deviation ** 2) / $variance;
    }

    /**
     * Returns the training samples whose target equals $label.
     *
     * array_count_values() turns numeric-string targets such as '1' into int keys. So the
     * comparison is done on string forms: it matches 1 with '1', yet keeps 0 and 'a' apart
     * (loose == treats them as equal on PHP 7). Only int/string targets can be labels.
     *
     * @param int|string $label
     *
     * @return array
     */
    private function getSamplesByLabel($label): array
    {
        $samples = [];
        foreach ($this->targets as $index => $target) {
            if ((is_int($target) || is_string($target)) && (string) $target === (string) $label) {
                $samples[] = $this->samples[$index];
            }
        }

        return $samples;
    }

    /**
     * Predicts the most likely label for one sample.
     *
     * Under the naive (conditional independence) assumption, the posterior of a label is
     * proportional to its prior multiplied by the likelihood of each feature given the label.
     * The score is accumulated in log space: log prior plus the sum of per-feature
     * log-likelihoods. The product therefore becomes a sum, and tiny densities do not
     * underflow.
     *
     * @param array $sample feature vector laid out like the training samples
     *
     * @return mixed the highest-scoring label (int or string), or null when no labels were
     *               learned; on a tie the label seen first in the training targets wins
     */
    protected function predictSample(array $sample)
    {
        $bestLabel = null;
        $bestScore = -INF;

        foreach ($this->labels as $label) {
            $score = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; $i++) {
                $score += $this->sampleProbability($sample, $i, $label);
            }

            if ($bestLabel === null || $score > $bestScore) {
                $bestLabel = $label;
                $bestScore = $score;
            }
        }

        return $bestLabel;
    }
}
185