Completed
Pull Request — master (#30)
by
unknown
02:25
created

NaiveBayes::getSamplesByLabel()   A

Complexity

Conditions 3
Paths 3

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric   Value
dl       0
loc      8
rs       9.4285
c        0
b        0
f        0
cc       3
eloc     6
nc       3
nop      1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Statistic\Mean;
10
use Phpml\Math\Statistic\StandardDeviation;
11
12
/**
 * Naive Bayes classifier.
 *
 * Each feature of each label is modelled either as a nominal (discrete) variable, using the
 * relative frequency of every value, or as a continuous variable, using a normal distribution
 * fitted with the arithmetic mean and population standard deviation.
 */
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    // Marks a continuous feature (the spelling of the name is kept for backward compatibility).
    const CONTINUOS = 1;

    // Marks a nominal / categorical / discrete feature.
    const NOMINAL = 2;

    // Floor applied to probabilities and variances so that zeros never reach log() or a division.
    const SMALL_VALUE = 1e-32;

    /**
     * Population standard deviation per label and feature (meaningful for continuous features only).
     *
     * @var array
     */
    private $std = [];

    /**
     * Arithmetic mean per label and feature (meaningful for continuous features only).
     *
     * @var array
     */
    private $mean = [];

    /**
     * Per label and feature: map of (string form of) value => relative frequency.
     * Only filled for nominal features; continuous features hold an empty map.
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * Per label and feature: self::CONTINUOS or self::NOMINAL.
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Prior probability P(label) for every label seen during training.
     *
     * @var array
     */
    private $p = [];

    /**
     * @var int
     */
    private $sampleCount = 0;

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * Distinct labels seen during training.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Fits the model: stores the training data, collects the distinct labels and computes,
     * for every label, its prior probability and per-feature statistics.
     *
     * The previously stored samples and targets are replaced, not merged.
     *
     * @param array $samples list of feature vectors, each indexed 0..featureCount-1
     * @param array $targets label of each sample, index-aligned with $samples
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->sampleCount = count($samples);
        // The first sample defines the number of features; an empty training set has none.
        $this->featureCount = $this->sampleCount > 0 ? count(reset($samples)) : 0;

        // array_unique() returns a new array rather than modifying its argument, so the result
        // must be kept. Without this, every label would be processed once per training sample.
        $this->labels = array_values(array_unique($targets));

        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Calculates the statistics of every feature for one label and caches them in the private
     * arrays, so they are not recomputed on every prediction.
     *
     * A feature is treated as nominal when it has at most half as many distinct values as
     * samples, or when it contains any non-numeric value (a mean is meaningless there).
     * Otherwise it is continuous, and its mean and population standard deviation are stored.
     *
     * @param string|int $label   label value taken from the training targets
     * @param array      $samples the training samples carrying this label
     */
    private function calculateStatistics($label, array $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; $i++) {
            // Values of the i-th column across this label's samples.
            $values = array_column($samples, $i);
            $numValues = count($values);

            // array_count_values() only counts ints and strings (floats are skipped with a
            // warning), so count the string form of each value. The same string form is used
            // for lookups in logFeatureProbability().
            $freq = array_count_values(array_map('strval', $values));
            $isNumeric = count(array_filter($values, 'is_numeric')) === $numValues;

            if (count($freq) <= $numValues / 2 || !$isNumeric) {
                $this->dataType[$label][$i] = self::NOMINAL;
                $this->discreteProb[$label][$i] = array_map(function ($count) use ($numValues) {
                    return $count / $numValues;
                }, $freq);
            } else {
                $this->mean[$label][$i] = Mean::arithmetic($values);
                $this->std[$label][$i] = StandardDeviation::population($values, false);
            }
        }
    }

    /**
     * Returns ln P(feature value | label) for one feature of a sample, never lower than
     * ln(SMALL_VALUE).
     *
     * Nominal features use the relative frequency observed during training (unseen values get
     * SMALL_VALUE). Continuous features use the normal density with the stored mean and
     * standard deviation. The value is returned as a logarithm so the caller can add the
     * per-feature terms instead of multiplying them, which would underflow to zero.
     *
     * @param array      $sample  feature vector being classified
     * @param int        $feature index of the feature to evaluate
     * @param string|int $label   label to evaluate the feature against
     *
     * @return float
     */
    private function logFeatureProbability(array $sample, int $feature, $label): float
    {
        $floor = log(self::SMALL_VALUE);
        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] == self::NOMINAL) {
            // Nominal values were counted by their string form during training.
            $probability = $this->discreteProb[$label][$feature][(string) $value] ?? 0;

            return log(max($probability, self::SMALL_VALUE));
        }

        $std = $this->std[$label][$feature];
        $mean = $this->mean[$label][$feature];
        $variance = $std * $std;
        if ($variance == 0) {
            // A constant feature would make the density degenerate; use a tiny variance instead.
            $variance = self::SMALL_VALUE;
        }

        // Logarithm of the normal (Gaussian) probability density:
        //   ln N(x; mean, variance) = -ln(2 * pi * variance) / 2 - (x - mean)^2 / (2 * variance)
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        $logDensity = -0.5 * log(2 * pi() * $variance) - (($value - $mean) ** 2) / (2 * $variance);

        // Replace NaN and vanishingly small densities by the SMALL_VALUE floor.
        if (is_nan($logDensity) || $logDensity < $floor) {
            return $floor;
        }

        return $logDensity;
    }

    /**
     * Returns the training samples whose target equals the given label (loose comparison).
     *
     * @param string|int $label label value taken from the training targets
     *
     * @return array
     */
    private function getSamplesByLabel($label): array
    {
        $samples = [];
        for ($i = 0; $i < $this->sampleCount; $i++) {
            if ($this->targets[$i] == $label) {
                $samples[] = $this->samples[$i];
            }
        }

        return $samples;
    }

    /**
     * Predicts the most likely label for one sample.
     *
     * Naive Bayes assumes the features are independent given the label, so the posterior of a
     * label is proportional to its prior times the conditional probability of each feature.
     * That product is computed as a sum of logarithms to avoid floating-point underflow. The
     * label with the highest score wins. Returns null when the model has no labels.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $predictions = [];
        foreach ($this->labels as $label) {
            $score = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; $i++) {
                $score += $this->logFeatureProbability($sample, $i, $label);
            }
            $predictions[$label] = $score;
        }

        arsort($predictions, SORT_NUMERIC);
        reset($predictions);

        return key($predictions);
    }
}
150