Completed
Pull Request — master (#30)
by
unknown
02:29
created

NaiveBayes::train()   A

Complexity

Conditions 2
Paths 2

Size

Total Lines 15
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 15
rs 9.4285
c 0
b 0
f 0
cc 2
eloc 11
nc 2
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
12
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    /** Marks a feature column whose values are all numeric; it is modelled with a Gaussian. */
    const CONTINUOS = 1;

    /** Marks a feature column containing a non-numeric value; it is modelled by relative frequency. */
    const NOMINAL = 2;

    /** Floor applied to every per-feature likelihood so a single zero cannot veto a label. */
    const SMALL_VALUE = 1e-32;

    /** @var array label => list of per-feature population standard deviations (continuous columns) */
    private $std = [];

    /** @var array label => list of per-feature arithmetic means (continuous columns) */
    private $mean = [];

    /** @var array label => list of per-feature maps "value => relative frequency" (nominal columns) */
    private $discreteProb = [];

    /** @var array label => list of per-feature column types (CONTINUOS or NOMINAL) */
    private $dataType = [];

    /** @var array label => prior probability P(label) */
    private $p = [];

    /** @var int number of training samples */
    private $sampleCount = 0;

    /** @var int number of features, taken from the first training sample */
    private $featureCount = 0;

    /** @var array distinct training labels, in order of first appearance */
    private $labels = [];

    /**
     * Fits the model.
     *
     * Stores the training data, then computes the prior P(label) of every
     * distinct label and the per-feature statistics of the samples carrying
     * that label. Calling train() again replaces the previously learned model.
     *
     * @param array $samples list of feature vectors; all are assumed to have the first sample's length
     * @param array $targets label of each sample, aligned by index with $samples
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->sampleCount = count($samples);
        $this->featureCount = $samples === [] ? 0 : count(reset($samples));

        // Drop statistics of a previous training run.
        $this->std = [];
        $this->mean = [];
        $this->discreteProb = [];
        $this->dataType = [];
        $this->p = [];

        // array_unique() returns a new array; its result must be kept, otherwise
        // every occurrence of a label is processed (and its statistics recomputed).
        $this->labels = array_values(array_unique($targets));

        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Calculates the statistics needed at prediction time for one label and
     * caches them, so they are not recomputed for every prediction.
     *
     * A column containing any non-numeric value is treated as nominal: the
     * relative frequency of each value is stored. Otherwise the column is
     * continuous: its mean and population standard deviation are stored.
     *
     * @param mixed $label
     * @param array $samples the training samples carrying $label
     */
    private function calculateStatistics($label, array $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; $i++) {
            $values = array_column($samples, $i);
            $numValues = count($values);

            if ($values !== array_filter($values, 'is_numeric')) {
                $this->dataType[$label][$i] = self::NOMINAL;
                $this->discreteProb[$label][$i] = array_map(
                    function ($count) use ($numValues) {
                        return $count / $numValues;
                    },
                    array_count_values($values)
                );
                continue;
            }

            // NOTE(review): StandardDeviation::population() presumably computes the mean again.
            $this->mean[$label][$i] = Mean::arithmetic($values);
            $this->std[$label][$i] = StandardDeviation::population($values, false);
        }
    }

    /**
     * Returns log P(feature value | label), floored at log(SMALL_VALUE).
     *
     * Nominal features use the relative frequency observed in training; a
     * value never seen with this label gets the floor. Continuous features
     * use the normal (Gaussian) density with the label's mean and standard
     * deviation; a zero variance is replaced by SMALL_VALUE.
     * Ref: https://en.wikipedia.org/wiki/Normal_distribution
     *
     * @param array $sample
     * @param int $feature column index
     * @param mixed $label
     * @return float
     */
    private function logSampleProbability(array $sample, int $feature, $label): float
    {
        $logFloor = log(self::SMALL_VALUE);
        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][$value] ?? 0;

            return $probability > 0 ? max(log($probability), $logFloor) : $logFloor;
        }

        $variance = $this->std[$label][$feature] ** 2;
        if ((float) $variance === 0.0) {
            $variance = self::SMALL_VALUE;
        }

        // ln of the normal density, computed directly rather than as ln(exp(...)),
        // so a far-away value yields a large negative number instead of underflowing.
        $logDensity = -0.5 * log(2 * M_PI * $variance)
            - (($value - $this->mean[$label][$feature]) ** 2) / (2 * $variance);

        if (is_nan($logDensity)) {
            return $logFloor;
        }

        return max($logDensity, $logFloor);
    }

    /**
     * Returns the training samples whose target equals $label.
     *
     * Targets are compared as strings, the same comparison array_unique()
     * uses to build the label list, so distinct labels such as "1.0" and "1"
     * do not pull in each other's samples.
     *
     * @param mixed $label
     * @return array
     */
    private function getSamplesByLabel($label): array
    {
        $matching = [];
        foreach ($this->samples as $index => $sample) {
            if ((string) $this->targets[$index] === (string) $label) {
                $matching[] = $sample;
            }
        }

        return $matching;
    }

    /**
     * Predicts the label of one sample.
     *
     * For backward compatibility a list of samples is also accepted, in which
     * case a list of predicted labels is returned.
     *
     * @param array $sample a feature vector, or a list of feature vectors
     * @return mixed the predicted label, or a list of labels for a batch
     */
    protected function predictSample(array $sample)
    {
        if (!isset($sample[0]) || !is_array($sample[0])) {
            return $this->predictLabel($sample);
        }

        $predictions = [];
        foreach ($sample as $single) {
            $predictions[] = $this->predictLabel($single);
        }

        return $predictions;
    }

    /**
     * Returns the most likely label for a single feature vector.
     *
     * Under the naive Bayes assumption a label's score is its prior multiplied
     * by the likelihood of each feature given that label; the label with the
     * highest score wins. The score is accumulated as a sum of logarithms:
     * multiplying the factors directly underflows to 0.0 once enough of them are
     * small (eleven floored factors of 1e-32 are already below the smallest
     * double), which made every label tie. Since log is monotonic, the ranking
     * is otherwise the same.
     *
     * @param array $sample
     * @return mixed the winning label, or null when no labels were trained
     */
    private function predictLabel(array $sample)
    {
        $scores = [];
        foreach ($this->labels as $label) {
            $score = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; $i++) {
                $score += $this->logSampleProbability($sample, $i, $label);
            }
            $scores[$label] = $score;
        }

        arsort($scores, SORT_NUMERIC);
        reset($scores);

        return key($scores);
    }
}
158