Completed
Push — master ( cb5a99...e603d6 )
by Arkadiusz
06:54
created

NaiveBayes::predictSample()   B

Complexity

Conditions 6
Paths 16

Size

Total Lines 30
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 30
rs 8.439
c 0
b 0
f 0
cc 6
eloc 20
nc 16
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
12
/**
 * Naive Bayes classifier handling both continuous and nominal features.
 *
 * A feature column is treated as continuous (modelled with a normal
 * distribution) when every value seen for a label is numeric, and as nominal
 * (modelled with relative frequencies) otherwise. All scoring is done in
 * log space so that many small likelihoods can be combined without
 * underflow.
 *
 * Training data is kept in $this->samples / $this->targets, which this class
 * does not declare itself (presumably provided by the Trainable trait —
 * confirm against the trait).
 */
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    // NOTE: "CONTINUOS" is misspelled, but the name is part of the public
    // surface of the class and is therefore kept as is.
    const CONTINUOS = 1;
    const NOMINAL = 2;
    const EPSILON = 1e-10;

    /** @var array standard deviation per label and feature index */
    private $std = [];

    /** @var array arithmetic mean per label and feature index */
    private $mean = [];

    /** @var array relative frequency of each nominal value per label and feature index */
    private $discreteProb = [];

    /** @var array self::CONTINUOS or self::NOMINAL per label and feature index */
    private $dataType = [];

    /** @var array prior probability P(label) per label */
    private $p = [];

    /** @var int number of training samples */
    private $sampleCount = 0;

    /** @var int number of features per sample */
    private $featureCount = 0;

    /** @var array distinct labels seen during training */
    private $labels = [];

    /**
     * Stores the training set and pre-computes per-label priors and
     * per-feature statistics.
     *
     * @param array $samples list of samples; each sample is a list of feature values
     * @param array $targets label of each sample, keyed like $samples
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->sampleCount = count($samples);

        // The feature count is taken from the first sample, whatever its key is.
        $firstSample = reset($samples);
        $this->featureCount = is_array($firstSample) ? count($firstSample) : 0;

        // array_unique() returns the de-duplicated array instead of modifying
        // its argument, so the result has to be assigned.
        $this->labels = array_values(array_unique($targets));

        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Calculates the statistics needed for prediction for one label and
     * caches them in the private arrays so they are computed only once.
     *
     * @param string $label
     * @param array  $samples training samples belonging to $label
     */
    private function calculateStatistics($label, $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($feature = 0; $feature < $this->featureCount; $feature++) {
            // Values of the current feature column for this label
            $values = array_column($samples, $feature);
            $numValues = count($values);

            // A column containing any non-numeric value is treated as
            // nominal/categorical/discrete.
            if ($values !== array_filter($values, 'is_numeric')) {
                $this->dataType[$label][$feature] = self::NOMINAL;
                $this->discreteProb[$label][$feature] = array_map(
                    function ($occurrences) use ($numValues) {
                        return $occurrences / $numValues;
                    },
                    array_count_values($values)
                );

                continue;
            }

            $this->mean[$label][$feature] = Mean::arithmetic($values);
            // Epsilon keeps the standard deviation strictly positive so the
            // Gaussian log-density never divides by zero.
            $this->std[$label][$feature] = self::EPSILON + StandardDeviation::population($values, false);
        }
    }

    /**
     * Calculates the log-likelihood log P(feature value | label) of a single
     * feature of the given sample.
     *
     * Nominal features use the relative frequency observed during training
     * (self::EPSILON for values never seen with this label); continuous
     * features use the density of the normal distribution.
     *
     * @param array  $sample
     * @param int    $feature
     * @param string $label
     *
     * @return float
     *
     * @throws \InvalidArgumentException when the sample has no value for $feature
     */
    private function sampleProbability($sample, $feature, $label)
    {
        if (!array_key_exists($feature, $sample)) {
            throw new \InvalidArgumentException(
                'Missing feature. All samples must have equal number of features'
            );
        }

        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] == self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][$value] ?? 0;

            // Take the log so that nominal and continuous features are
            // combined on the same scale by the caller.
            return log($probability > 0 ? $probability : self::EPSILON);
        }

        $std = $this->std[$label][$feature];
        $mean = $this->mean[$label][$feature];

        // Log of the normal/Gaussian probability density.
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        //
        // Working with logarithms avoids numerical errors caused by very
        // small or zero values, as done by scikit-learn.
        // (See: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py)
        $variance = $std * $std;
        $logDensity = -0.5 * log(2.0 * pi() * $variance);
        $logDensity -= 0.5 * pow($value - $mean, 2) / $variance;

        return $logDensity;
    }

    /**
     * Returns the training samples belonging to a specific label.
     *
     * Labels are compared by their string representation, which matches how
     * array_unique() de-duplicates them in train().
     *
     * @param string $label
     *
     * @return array
     */
    private function getSamplesByLabel($label)
    {
        $wanted = (string) $label;
        $matching = [];

        foreach ($this->samples as $index => $sample) {
            if (isset($this->targets[$index]) && (string) $this->targets[$index] === $wanted) {
                $matching[] = $sample;
            }
        }

        return $matching;
    }

    /**
     * Predicts the label of one sample, or of each sample when a list of
     * samples is passed.
     *
     * @param array $sample a single sample or a list of samples
     *
     * @return mixed the predicted label, or a list of labels for a list of samples
     */
    protected function predictSample(array $sample)
    {
        $isBatch = is_array($sample[0] ?? null);
        $samples = $isBatch ? $sample : [$sample];

        $samplePredictions = [];
        foreach ($samples as $current) {
            $samplePredictions[] = $this->predictLabel($current);
        }

        return $isBatch ? $samplePredictions : $samplePredictions[0];
    }

    /**
     * Returns the label with the highest log-posterior score for one sample.
     *
     * The Naive Bayes assumption gives, up to a constant,
     *   log P(label|features) = log P(label) + sum over i of log P(feature_i|label)
     *
     * @param array $sample
     *
     * @return mixed the most likely label, or null when no label is known
     */
    private function predictLabel(array $sample)
    {
        $scores = [];

        foreach ($this->labels as $label) {
            $score = log($this->p[$label]);
            for ($feature = 0; $feature < $this->featureCount; $feature++) {
                $score += $this->sampleProbability($sample, $feature, $label);
            }

            $scores[$label] = $score;
        }

        // Highest score first; the first key is the predicted label.
        arsort($scores, SORT_NUMERIC);
        reset($scores);

        return key($scores);
    }
}
158