Completed
Pull Request — master (#36)
by
unknown
03:44 queued 01:01
created

NaiveBayes::calculateStatistics()   B

Complexity

Conditions 3
Paths 3

Size

Total Lines 27
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 27
rs 8.8571
c 0
b 0
f 0
cc 3
eloc 18
nc 3
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
12
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    /** Feature whose training values are all numeric; modelled with a Gaussian. */
    const CONTINUOS = 1;
    /** Feature with at least one non-numeric training value; modelled by relative frequency. */
    const NOMINAL = 2;
    /** Floor used for zero standard deviations and for unseen nominal values. */
    const EPSILON = 1e-10;

    /**
     * Per label, per feature: standard deviation (only meaningful for CONTINUOS features).
     *
     * @var array
     */
    private $std = [];

    /**
     * Per label, per feature: arithmetic mean (only meaningful for CONTINUOS features).
     *
     * @var array
     */
    private $mean = [];

    /**
     * Per label, per feature: map of value => relative frequency (only meaningful for NOMINAL features).
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * Per label, per feature: self::CONTINUOS or self::NOMINAL.
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Per label: prior probability P(label), i.e. the share of training samples with that label.
     *
     * @var array
     */
    private $p = [];

    /**
     * @var int
     */
    private $sampleCount = 0;

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * @var array
     */
    private $labels = [];

    /**
     * Adds the given samples to the training set and recomputes all per-label statistics.
     *
     * Training is cumulative: samples and targets from earlier calls are kept.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \InvalidArgumentException when the sample and target counts differ,
     *                                   or when there is nothing to train on
     */
    public function train(array $samples, array $targets)
    {
        if (count($samples) !== count($targets)) {
            throw new \InvalidArgumentException('Number of samples and targets must be equal');
        }

        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);
        $this->sampleCount = count($this->samples);
        if ($this->sampleCount === 0) {
            throw new \InvalidArgumentException('Cannot train NaiveBayes on an empty dataset');
        }
        $this->featureCount = count($this->samples[0]);

        // Count over ALL accumulated targets (not only this batch) so labels
        // seen in earlier train() calls are kept.
        $labelCounts = array_count_values($this->targets);
        $this->labels = array_keys($labelCounts);
        foreach ($this->labels as $label) {
            $samples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($samples) / $this->sampleCount;
            $this->calculateStatistics($label, $samples);
        }
    }

    /**
     * Calculates vital statistics for each label & feature. Stores these
     * values in private arrays in order to avoid repeated calculation.
     *
     * A feature column containing any non-numeric value is treated as
     * nominal/categorical and stored as value => relative frequency.
     * Otherwise its mean and population standard deviation are stored.
     *
     * @param string|int $label
     * @param array      $samples samples belonging to $label
     */
    private function calculateStatistics($label, $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        // Placeholder; only read for features later marked NOMINAL, which overwrite it.
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; ++$i) {
            // Values of the i-th column across this label's samples
            $values = array_column($samples, $i);
            $numValues = count($values);

            // array_filter preserves keys, so the arrays are identical only when
            // every value is numeric.
            if ($values !== array_filter($values, 'is_numeric')) {
                $this->dataType[$label][$i] = self::NOMINAL;
                // array_map with a single array preserves the value keys.
                $this->discreteProb[$label][$i] = array_map(
                    function ($count) use ($numValues) {
                        return $count / $numValues;
                    },
                    array_count_values($values)
                );
            } else {
                $this->mean[$label][$i] = Mean::arithmetic($values);
                // Add epsilon in order to avoid a zero stdev (division by zero / log(0))
                $this->std[$label][$i] = self::EPSILON + StandardDeviation::population($values, false);
            }
        }
    }

    /**
     * Calculates the log-likelihood log P(sample[feature] | label).
     *
     * Nominal features return the log of the value's relative frequency, floored
     * at EPSILON for unseen values. Continuous features return the log of the
     * Gaussian probability density.
     *
     * @param array      $sample
     * @param int        $feature
     * @param string|int $label
     *
     * @return float
     */
    private function sampleProbability($sample, $feature, $label)
    {
        $value = $sample[$feature];
        if ($this->dataType[$label][$feature] == self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][$value] ?? 0;
            // An unseen value would give log(0) = -INF; floor it at EPSILON instead.
            return log($probability > 0 ? $probability : self::EPSILON);
        }

        $std = $this->std[$label][$feature];
        $mean = $this->mean[$label][$feature];
        // Calculate the probability density by use of normal/Gaussian distribution
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        //
        // In order to avoid numerical errors because of small or zero values,
        // the log of the density is used, as scikit-learn does.
        // (See : https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py)
        $pdf = -0.5 * log(2.0 * pi() * $std * $std);
        $pdf -= 0.5 * pow($value - $mean, 2) / ($std * $std);

        return $pdf;
    }

    /**
     * Return samples belonging to specific label.
     *
     * Uses loose comparison on purpose: array_keys() turns numeric-string
     * labels into ints, while the stored targets may still be strings.
     *
     * @param string|int $label
     *
     * @return array
     */
    private function getSamplesByLabel($label)
    {
        $samples = [];
        for ($i = 0; $i < $this->sampleCount; ++$i) {
            if ($this->targets[$i] == $label) {
                $samples[] = $this->samples[$i];
            }
        }

        return $samples;
    }

    /**
     * Predicts the most likely label for a single sample.
     *
     * @param array $sample
     *
     * @return mixed the label with the highest posterior score
     */
    protected function predictSample(array $sample)
    {
        // Naive Bayes assumption:
        //   P(label|features) ∝ P(label) * P(feature0|label) * ... * P(featureN|label)
        // evaluated in log space to avoid underflow:
        //   log P(label) + Σ log P(featureI|label)
        // sampleProbability() already returns log-likelihoods, so the prior is logged too.
        $predictions = [];
        foreach ($this->labels as $label) {
            // The prior is > 0 because every label has at least one training sample.
            $logPosterior = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; ++$i) {
                $logPosterior += $this->sampleProbability($sample, $i, $label);
            }
            $predictions[$label] = $logPosterior;
        }

        arsort($predictions, SORT_NUMERIC);
        reset($predictions);

        return key($predictions);
    }
}
193