@@ -12,112 +12,112 @@ discard block |
||
| 12 | 12 | class NaiveBayes implements Classifier |
| 13 | 13 | { |
| 14 | 14 | use Trainable, Predictable; |
| 15 | - const CONTINUOS = 1; |
|
| 16 | - const NOMINAL = 2; |
|
| 17 | - const SMALL_VALUE = 1e-32; |
|
| 18 | - private $std = array(); |
|
| 19 | - private $mean= array(); |
|
| 20 | - private $discreteProb = array(); |
|
| 21 | - private $dataType = array(); |
|
| 22 | - private $p = array(); |
|
| 23 | - private $sampleCount = 0; |
|
| 24 | - private $featureCount = 0; |
|
| 25 | - private $labels = array(); |
|
| 26 | - public function train(array $samples, array $targets) |
|
| 15 | + const CONTINUOS = 1; |
|
| 16 | + const NOMINAL = 2; |
|
| 17 | + const SMALL_VALUE = 1e-32; |
|
| 18 | + private $std = array(); |
|
| 19 | + private $mean= array(); |
|
| 20 | + private $discreteProb = array(); |
|
| 21 | + private $dataType = array(); |
|
| 22 | + private $p = array(); |
|
| 23 | + private $sampleCount = 0; |
|
| 24 | + private $featureCount = 0; |
|
| 25 | + private $labels = array(); |
|
| 26 | + public function train(array $samples, array $targets) |
|
| 27 | 27 | { |
| 28 | 28 | $this->samples = $samples; |
| 29 | 29 | $this->targets = $targets; |
| 30 | - $this->sampleCount = count($samples); |
|
| 31 | - $this->featureCount = count($samples[0]); |
|
| 32 | - // Get distinct targets |
|
| 33 | - $this->labels = $targets; |
|
| 34 | - array_unique($this->labels); |
|
| 35 | - foreach($this->labels as $label) |
|
| 36 | - { |
|
| 37 | - $samples = $this->getSamplesByLabel($label); |
|
| 38 | - $this->p[$label] = count($samples) / $this->sampleCount; |
|
| 39 | - $this->calculateStatistics($label, $samples); |
|
| 40 | - } |
|
| 30 | + $this->sampleCount = count($samples); |
|
| 31 | + $this->featureCount = count($samples[0]); |
|
| 32 | + // Get distinct targets |
|
| 33 | + $this->labels = $targets; |
|
| 34 | + array_unique($this->labels); |
|
| 35 | + foreach($this->labels as $label) |
|
| 36 | + { |
|
| 37 | + $samples = $this->getSamplesByLabel($label); |
|
| 38 | + $this->p[$label] = count($samples) / $this->sampleCount; |
|
| 39 | + $this->calculateStatistics($label, $samples); |
|
| 40 | + } |
|
| 41 | 41 | } |
| 42 | 42 | |
| 43 | - /** |
|
| 44 | - * Calculates vital statistics for each label & feature. Stores these |
|
| 45 | - * values in private array in order to avoid repeated calculation |
|
| 46 | - * @param string $label |
|
| 47 | - * @param array $samples |
|
| 48 | - */ |
|
| 49 | - private function calculateStatistics($label, $samples) |
|
| 50 | - { |
|
| 51 | - $this->std[$label] = array_fill(0, $this->featureCount, 0); |
|
| 52 | - $this->mean[$label]= array_fill(0, $this->featureCount, 0); |
|
| 53 | - $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS); |
|
| 54 | - $this->discreteProb[$label] = array_fill(0, $this->featureCount, self::CONTINUOS); |
|
| 55 | - for($i=0; $i<$this->featureCount; $i++) |
|
| 56 | - { |
|
| 57 | - // Get the values of nth column in the samples array |
|
| 58 | - // Mean::arithmetic is called twice, can be optimized |
|
| 59 | - $values = array_column($samples, $i); |
|
| 60 | - $numValues = count($values); |
|
| 61 | - $freq = array_count_values($values); |
|
| 62 | - // if values are of only a few types or the column contain categorical/string |
|
| 63 | - // values, then it should be treated as nominal/categorical/discrete column |
|
| 64 | - if (count($freq) <= $numValues / 2) |
|
| 65 | - { |
|
| 66 | - $this->dataType[$label][$i] = self::NOMINAL; |
|
| 67 | - $this->discreteProb[$label][$i] = $freq; |
|
| 68 | - $db = &$this->discreteProb[$label][$i]; |
|
| 69 | - $db = array_map(function($el) use ($numValues) { |
|
| 70 | - return $el / $numValues; |
|
| 71 | - }, $db); |
|
| 72 | - } |
|
| 73 | - else |
|
| 74 | - { |
|
| 75 | - $this->mean[$label][$i] = Mean::arithmetic($values); |
|
| 76 | - $this->std[$label][$i] = StandardDeviation::population($values, false); |
|
| 77 | - } |
|
| 78 | - } |
|
| 79 | - } |
|
| 43 | + /** |
|
| 44 | + * Calculates vital statistics for each label & feature. Stores these |
|
| 45 | + * values in private array in order to avoid repeated calculation |
|
| 46 | + * @param string $label |
|
| 47 | + * @param array $samples |
|
| 48 | + */ |
|
| 49 | + private function calculateStatistics($label, $samples) |
|
| 50 | + { |
|
| 51 | + $this->std[$label] = array_fill(0, $this->featureCount, 0); |
|
| 52 | + $this->mean[$label]= array_fill(0, $this->featureCount, 0); |
|
| 53 | + $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS); |
|
| 54 | + $this->discreteProb[$label] = array_fill(0, $this->featureCount, self::CONTINUOS); |
|
| 55 | + for($i=0; $i<$this->featureCount; $i++) |
|
| 56 | + { |
|
| 57 | + // Get the values of nth column in the samples array |
|
| 58 | + // Mean::arithmetic is called twice, can be optimized |
|
| 59 | + $values = array_column($samples, $i); |
|
| 60 | + $numValues = count($values); |
|
| 61 | + $freq = array_count_values($values); |
|
| 62 | + // if values are of only a few types or the column contain categorical/string |
|
| 63 | + // values, then it should be treated as nominal/categorical/discrete column |
|
| 64 | + if (count($freq) <= $numValues / 2) |
|
| 65 | + { |
|
| 66 | + $this->dataType[$label][$i] = self::NOMINAL; |
|
| 67 | + $this->discreteProb[$label][$i] = $freq; |
|
| 68 | + $db = &$this->discreteProb[$label][$i]; |
|
| 69 | + $db = array_map(function($el) use ($numValues) { |
|
| 70 | + return $el / $numValues; |
|
| 71 | + }, $db); |
|
| 72 | + } |
|
| 73 | + else |
|
| 74 | + { |
|
| 75 | + $this->mean[$label][$i] = Mean::arithmetic($values); |
|
| 76 | + $this->std[$label][$i] = StandardDeviation::population($values, false); |
|
| 77 | + } |
|
| 78 | + } |
|
| 79 | + } |
|
| 80 | 80 | |
| 81 | - /** |
|
| 82 | - * Calculates the probability P(label|sample_n) assuming |
|
| 83 | - * the feature is a continuous value |
|
| 84 | - * @param array $sample |
|
| 85 | - * @param int $feature |
|
| 86 | - * @param string $label |
|
| 87 | - */ |
|
| 88 | - private function sampleProbability($sample, $feature, $label) |
|
| 89 | - { |
|
| 90 | - $value = $sample[$feature]; |
|
| 91 | - if ($this->dataType[$label][$feature] == self::NOMINAL) |
|
| 92 | - { |
|
| 93 | - if (! isset($this->discreteProb[$label][$feature][$value]) || |
|
| 94 | - $this->discreteProb[$label][$feature][$value] == 0) |
|
| 95 | - return self::SMALL_VALUE; |
|
| 96 | - return $this->discreteProb[$label][$feature][$value]; |
|
| 97 | - } |
|
| 98 | - $std = $this->std[$label][$feature] ; |
|
| 99 | - $mean= $this->mean[$label][$feature]; |
|
| 100 | - $std2 = $std * $std; |
|
| 101 | - if ($std2 == 0) |
|
| 102 | - $std2 = self::SMALL_VALUE; |
|
| 103 | - // Calculate the probability density by use of normal/Gaussian distribution |
|
| 104 | - // Ref: https://en.wikipedia.org/wiki/Normal_distribution |
|
| 105 | - return (1 / sqrt(2 * $std2 * pi())) * exp( - pow($value - $mean, 2) / (2 * $std2)); |
|
| 106 | - } |
|
| 81 | + /** |
|
| 82 | + * Calculates the probability P(label|sample_n) assuming |
|
| 83 | + * the feature is a continuous value |
|
| 84 | + * @param array $sample |
|
| 85 | + * @param int $feature |
|
| 86 | + * @param string $label |
|
| 87 | + */ |
|
| 88 | + private function sampleProbability($sample, $feature, $label) |
|
| 89 | + { |
|
| 90 | + $value = $sample[$feature]; |
|
| 91 | + if ($this->dataType[$label][$feature] == self::NOMINAL) |
|
| 92 | + { |
|
| 93 | + if (! isset($this->discreteProb[$label][$feature][$value]) || |
|
| 94 | + $this->discreteProb[$label][$feature][$value] == 0) |
|
| 95 | + return self::SMALL_VALUE; |
|
| 96 | + return $this->discreteProb[$label][$feature][$value]; |
|
| 97 | + } |
|
| 98 | + $std = $this->std[$label][$feature] ; |
|
| 99 | + $mean= $this->mean[$label][$feature]; |
|
| 100 | + $std2 = $std * $std; |
|
| 101 | + if ($std2 == 0) |
|
| 102 | + $std2 = self::SMALL_VALUE; |
|
| 103 | + // Calculate the probability density by use of normal/Gaussian distribution |
|
| 104 | + // Ref: https://en.wikipedia.org/wiki/Normal_distribution |
|
| 105 | + return (1 / sqrt(2 * $std2 * pi())) * exp( - pow($value - $mean, 2) / (2 * $std2)); |
|
| 106 | + } |
|
| 107 | 107 | |
| 108 | - /** |
|
| 109 | - * Return samples belonging to specific label |
|
| 110 | - * @param string $label |
|
| 111 | - * @return array |
|
| 112 | - */ |
|
| 113 | - private function getSamplesByLabel($label) |
|
| 114 | - { |
|
| 115 | - $samples = array(); |
|
| 116 | - for($i=0; $i<$this->sampleCount; $i++) |
|
| 117 | - if ($this->targets[$i] == $label) |
|
| 118 | - $samples[] = $this->samples[$i]; |
|
| 119 | - return $samples; |
|
| 120 | - } |
|
| 108 | + /** |
|
| 109 | + * Return samples belonging to specific label |
|
| 110 | + * @param string $label |
|
| 111 | + * @return array |
|
| 112 | + */ |
|
| 113 | + private function getSamplesByLabel($label) |
|
| 114 | + { |
|
| 115 | + $samples = array(); |
|
| 116 | + for($i=0; $i<$this->sampleCount; $i++) |
|
| 117 | + if ($this->targets[$i] == $label) |
|
| 118 | + $samples[] = $this->samples[$i]; |
|
| 119 | + return $samples; |
|
| 120 | + } |
|
| 121 | 121 | |
| 122 | 122 | /** |
| 123 | 123 | * @param array $sample |
@@ -125,23 +125,23 @@ discard block |
||
| 125 | 125 | */ |
| 126 | 126 | protected function predictSample(array $sample) |
| 127 | 127 | { |
| 128 | - // Use NaiveBayes assumption for each label using: |
|
| 129 | - // P(label|features) = P(label) * P(feature0|label) * P(feature1|label) .... P(featureN|label) |
|
| 130 | - // Then compare probability for each class to determine which is most likely |
|
| 131 | - $predictions = array(); |
|
| 132 | - foreach($this->labels as $label) |
|
| 133 | - { |
|
| 134 | - $p = $this->p[$label]; |
|
| 135 | - for($i=0; $i<$this->featureCount; $i++) |
|
| 136 | - { |
|
| 137 | - $Plf = $this->sampleProbability($sample, $i, $label); |
|
| 138 | - // Correct the value for small and zero values |
|
| 139 | - if (is_nan($Plf) || $Plf < self::SMALL_VALUE) |
|
| 140 | - $Plf = self::SMALL_VALUE; |
|
| 141 | - $p *= $Plf; |
|
| 142 | - } |
|
| 143 | - $predictions[$label] = $p; |
|
| 144 | - } |
|
| 128 | + // Use NaiveBayes assumption for each label using: |
|
| 129 | + // P(label|features) = P(label) * P(feature0|label) * P(feature1|label) .... P(featureN|label) |
|
| 130 | + // Then compare probability for each class to determine which is most likely |
|
| 131 | + $predictions = array(); |
|
| 132 | + foreach($this->labels as $label) |
|
| 133 | + { |
|
| 134 | + $p = $this->p[$label]; |
|
| 135 | + for($i=0; $i<$this->featureCount; $i++) |
|
| 136 | + { |
|
| 137 | + $Plf = $this->sampleProbability($sample, $i, $label); |
|
| 138 | + // Correct the value for small and zero values |
|
| 139 | + if (is_nan($Plf) || $Plf < self::SMALL_VALUE) |
|
| 140 | + $Plf = self::SMALL_VALUE; |
|
| 141 | + $p *= $Plf; |
|
| 142 | + } |
|
| 143 | + $predictions[$label] = $p; |
|
| 144 | + } |
|
| 145 | 145 | arsort($predictions, SORT_NUMERIC); |
| 146 | 146 | reset($predictions); |
| 147 | 147 | return key($predictions); |