Completed
Pull Request — master (#30)
by
unknown
02:36
created

NaiveBayes   A

Complexity

Total Complexity 21

Size/Duplication

Total Lines 147
Duplicated Lines 0 %

Coupling/Cohesion

Components 1
Dependencies 4

Importance

Changes 0
Metric Value
wmc 21
lcom 1
cbo 4
dl 0
loc 147
rs 10
c 0
b 0
f 0

5 Methods

Rating   Name   Duplication   Size   Complexity  
A train() 0 15 2
B calculateStatistics() 0 27 3
B sampleProbability() 0 20 5
A getSamplesByLabel() 0 10 3
C predictSample() 0 34 8
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
12
class NaiveBayes implements Classifier
{
    use Trainable;
    use Predictable;

    /**
     * Feature type marker: the column is modelled with a Gaussian density.
     */
    const CONTINUOUS = 1;

    /**
     * Misspelled alias of CONTINUOUS, kept so existing references keep working.
     */
    const CONTINUOS = self::CONTINUOUS;

    /**
     * Feature type marker: the column is modelled with a relative-frequency table.
     */
    const NOMINAL = 2;

    /**
     * Floor applied to variances and likelihoods so they never reach zero.
     */
    const SMALL_VALUE = 1e-32;

    /** @var array label => feature index => population standard deviation (continuous features) */
    private $std = [];

    /** @var array label => feature index => arithmetic mean (continuous features) */
    private $mean = [];

    /** @var array label => feature index => (string value => relative frequency) (nominal features) */
    private $discreteProb = [];

    /** @var array label => feature index => self::CONTINUOUS | self::NOMINAL */
    private $dataType = [];

    /** @var array label => prior probability P(label) */
    private $p = [];

    /** @var int number of training samples */
    private $sampleCount = 0;

    /** @var int number of features, taken from the first training sample */
    private $featureCount = 0;

    /** @var array distinct training labels */
    private $labels = [];

    /**
     * Fits the model. Stores the training set, then for every distinct label
     * computes the prior P(label) and the per-feature statistics.
     *
     * @param array $samples list of feature vectors
     * @param array $targets list of labels, aligned index-by-index with $samples
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->sampleCount = count($samples);
        $this->featureCount = isset($samples[0]) ? count($samples[0]) : 0;

        // Forget statistics from any previous call so retraining starts clean.
        $this->p = [];
        $this->std = [];
        $this->mean = [];
        $this->dataType = [];
        $this->discreteProb = [];

        // array_unique() returns a new array; it does not modify its argument.
        $this->labels = array_values(array_unique($targets));

        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Calculates and caches the statistics for each feature of one label.
     * A feature becomes NOMINAL, with a relative-frequency table keyed by the
     * value's string form, when it is non-numeric or has at most n/2 distinct
     * values. Otherwise it is CONTINUOUS, with a mean and population std.
     *
     * @param mixed $label
     * @param array $samples the training samples that carry $label
     */
    private function calculateStatistics($label, array $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOUS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; ++$i) {
            // Values of the i-th column across this label's samples.
            $values = array_column($samples, $i);
            $numValues = count($values);

            // array_count_values() only accepts ints and strings (floats are
            // skipped with a warning), so count the string form of each value.
            $freq = array_count_values(array_map('strval', $values));
            $allNumeric = count(array_filter($values, 'is_numeric')) === $numValues;

            // Columns with string values, or with only a few distinct values,
            // are treated as nominal/categorical/discrete.
            if (!$allNumeric || count($freq) <= $numValues / 2) {
                $this->dataType[$label][$i] = self::NOMINAL;
                $this->discreteProb[$label][$i] = array_map(function ($count) use ($numValues) {
                    return $count / $numValues;
                }, $freq);
            } else {
                // NOTE(review): StandardDeviation::population may recompute the
                // mean internally; confirm before optimising.
                $this->mean[$label][$i] = Mean::arithmetic($values);
                $this->std[$label][$i] = StandardDeviation::population($values, false);
            }
        }
    }

    /**
     * Returns the likelihood P(feature value | label) for one feature of a sample.
     * For nominal features this is the relative frequency observed in training
     * (SMALL_VALUE if the value was never seen). For continuous features it is
     * the Gaussian density at the value.
     *
     * @param array $sample
     * @param int $feature feature index
     * @param mixed $label
     * @return float
     */
    private function sampleProbability($sample, int $feature, $label): float
    {
        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            // Frequencies are keyed by the value's string form (see calculateStatistics).
            $probability = $this->discreteProb[$label][$feature][(string) $value] ?? 0;

            return $probability > 0 ? $probability : self::SMALL_VALUE;
        }

        // A non-numeric value cannot be evaluated under a Gaussian.
        if (!is_numeric($value)) {
            return self::SMALL_VALUE;
        }

        $mean = $this->mean[$label][$feature];
        $variance = $this->std[$label][$feature] ** 2;
        if ($variance <= 0) {
            $variance = self::SMALL_VALUE;
        }

        // Normal/Gaussian probability density.
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        return (1 / sqrt(2 * $variance * pi())) * exp(-(($value - $mean) ** 2) / (2 * $variance));
    }

    /**
     * Returns the training samples whose target equals $label. Comparison is done
     * on the string form, matching how array_unique() de-duplicated the labels.
     *
     * @param mixed $label
     * @return array
     */
    private function getSamplesByLabel($label): array
    {
        $wanted = (string) $label;
        $samples = [];
        for ($i = 0; $i < $this->sampleCount; ++$i) {
            if ((string) $this->targets[$i] === $wanted) {
                $samples[] = $this->samples[$i];
            }
        }

        return $samples;
    }

    /**
     * Predicts the label of one feature vector, or of each vector in a list of
     * vectors. A list is detected by its first element being an array.
     *
     * @param array $sample
     * @return mixed a single label, or a list of labels for a list input
     */
    protected function predictSample(array $sample)
    {
        $isBatch = isset($sample[0]) && is_array($sample[0]);
        $batch = $isBatch ? $sample : [$sample];

        $predictions = [];
        foreach ($batch as $features) {
            $predictions[] = $this->mostLikelyLabel($features);
        }

        return $isBatch ? $predictions : $predictions[0];
    }

    /**
     * Scores every known label for one feature vector and returns the best one.
     * Under the naive independence assumption the score of a label is its prior
     * times the product of the per-feature likelihoods. It is computed in log
     * space so that many small factors do not underflow to zero. Returns null
     * when the model has no labels.
     *
     * @param array $features
     * @return mixed
     */
    private function mostLikelyLabel($features)
    {
        $scores = [];
        foreach ($this->labels as $label) {
            $score = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; ++$i) {
                $likelihood = $this->sampleProbability($features, $i, $label);
                // Clamp NaN and tiny/zero likelihoods so log() stays finite.
                if (is_nan($likelihood) || $likelihood < self::SMALL_VALUE) {
                    $likelihood = self::SMALL_VALUE;
                }
                $score += log($likelihood);
            }
            $scores[$label] = $score;
        }

        arsort($scores, SORT_NUMERIC);
        reset($scores);

        return key($scores);
    }
}
159