Passed
Pull Request — master (#167)
by
unknown
02:11
created

Normalizer::preprocess()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 4
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
eloc 2
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Preprocessing;
6
7
use Phpml\Exception\NormalizerException;
8
use Phpml\Math\Statistic\Mean;
9
use Phpml\Math\Statistic\StandardDeviation;
10
11
/**
 * Rescales samples in place using one of three norms.
 *
 * - NORM_L1:  every sample is divided by the sum of the absolute values of its features.
 * - NORM_L2:  every sample is divided by the square root of the sum of its squared features.
 * - NORM_STD: every feature has the per-feature mean subtracted and is divided by the
 *             per-feature deviation; both statistics are computed once, in fit().
 */
class Normalizer implements Preprocessor
{
    public const NORM_L1 = 1;

    public const NORM_L2 = 2;

    public const NORM_STD = 3;

    /**
     * One of the NORM_* constants; validated in the constructor.
     *
     * @var int
     */
    private $norm;

    /**
     * Set to true once fit() has completed; later fit() calls are then no-ops.
     *
     * @var bool
     */
    private $fitted = false;

    /**
     * Per-feature deviation, indexed by feature position. Only filled for NORM_STD.
     *
     * @var array
     */
    private $std = [];

    /**
     * Per-feature arithmetic mean, indexed by feature position. Only filled for NORM_STD.
     *
     * @var array
     */
    private $mean = [];

    /**
     * @param int $norm One of NORM_L1, NORM_L2 (default) or NORM_STD.
     *
     * @throws NormalizerException When $norm is not one of the NORM_* constants.
     */
    public function __construct(int $norm = self::NORM_L2)
    {
        // Strict comparison: $norm is already an int, so no type juggling is wanted here.
        if (!in_array($norm, [self::NORM_L1, self::NORM_L2, self::NORM_STD], true)) {
            throw NormalizerException::unknownNorm();
        }

        $this->norm = $norm;
    }

    /**
     * Fits the normalizer on the given samples.
     *
     * The samples are received by value and are neither modified nor returned;
     * call transform() to actually rescale them.
     * NOTE(review): confirm against the Preprocessor contract that fitting only is intended here.
     *
     * @param array $samples
     */
    public function preprocess(array $samples)
    {
        $this->fit($samples);
    }

    /**
     * Computes the per-feature statistics needed by NORM_STD.
     *
     * Only the first successful call has an effect; once fitted, later calls return
     * immediately. For NORM_L1 and NORM_L2 nothing is computed and the normalizer is
     * simply marked as fitted.
     *
     * @param array $samples List of samples, each sample being a list of numeric features.
     */
    public function fit(array $samples): void
    {
        if ($this->fitted) {
            return;
        }

        if ($this->norm === self::NORM_STD) {
            if ($samples === []) {
                // Nothing to learn from; stay unfitted so a later call with data can fit.
                return;
            }

            // The feature count is taken from the first sample, whatever its key is.
            $featureCount = count(reset($samples));

            // A counted loop (instead of range(0, $featureCount - 1)) iterates zero
            // times for an empty first sample; range(0, -1) would yield [0, -1].
            for ($i = 0; $i < $featureCount; ++$i) {
                $values = array_column($samples, $i);
                $this->std[$i] = StandardDeviation::population($values);
                $this->mean[$i] = Mean::arithmetic($values);
            }
        }

        $this->fitted = true;
    }

    /**
     * Normalizes every sample in place according to the configured norm.
     *
     * Calls fit() first, so an unfitted NORM_STD normalizer derives its statistics
     * from the samples passed here.
     *
     * @param array $samples List of samples; modified by reference.
     */
    public function transform(array &$samples): void
    {
        $methods = [
            self::NORM_L1 => 'normalizeL1',
            self::NORM_L2 => 'normalizeL2',
            self::NORM_STD => 'normalizeSTD',
        ];
        $method = $methods[$this->norm];

        $this->fit($samples);

        foreach ($samples as &$sample) {
            $this->{$method}($sample);
        }
        // Break the reference so it cannot alias the last element afterwards.
        unset($sample);
    }

    /**
     * Divides each feature by the sum of absolute feature values.
     *
     * When that sum is zero, the sample is replaced by a uniform vector whose
     * entries are 1 / (number of features). An empty sample is left untouched.
     *
     * @param array $sample Single sample; modified by reference.
     */
    private function normalizeL1(array &$sample): void
    {
        if ($sample === []) {
            // Avoids 1.0 / 0 in the zero-norm branch below.
            return;
        }

        $norm1 = 0;
        foreach ($sample as $feature) {
            $norm1 += abs($feature);
        }

        // Loose comparison on purpose: the sum may be int 0 or float 0.0.
        if ($norm1 == 0) {
            $count = count($sample);
            $sample = array_fill(0, $count, 1.0 / $count);

            return;
        }

        foreach ($sample as &$feature) {
            $feature /= $norm1;
        }
        unset($feature);
    }

    /**
     * Divides each feature by the square root of the sum of squared features.
     *
     * When that value is zero, every feature of the sample is set to 1.
     *
     * @param array $sample Single sample; modified by reference.
     */
    private function normalizeL2(array &$sample): void
    {
        $norm2 = 0;
        foreach ($sample as $feature) {
            $norm2 += $feature * $feature;
        }

        $norm2 = sqrt((float) $norm2);

        if ($norm2 == 0) {
            $sample = array_fill(0, count($sample), 1);

            return;
        }

        foreach ($sample as &$feature) {
            $feature /= $norm2;
        }
        unset($feature);
    }

    /**
     * Standardizes each feature with the mean and deviation stored by fit().
     *
     * A feature whose stored deviation is zero is set to 0.
     *
     * @param array $sample Single sample; modified by reference.
     */
    private function normalizeSTD(array &$sample): void
    {
        foreach ($sample as $i => $val) {
            if ($this->std[$i] != 0) {
                $sample[$i] = ($val - $this->mean[$i]) / $this->std[$i];
            } else {
                // Same value for all samples.
                $sample[$i] = 0;
            }
        }
    }
}
140