1 | <?php |
||
2 | |||
3 | declare(strict_types=1); |
||
4 | |||
5 | namespace Phpml\Preprocessing; |
||
6 | |||
7 | use Phpml\Exception\NormalizerException; |
||
8 | use Phpml\Math\Statistic\Mean; |
||
9 | use Phpml\Math\Statistic\StandardDeviation; |
||
10 | |||
/**
 * Sample preprocessor offering three modes.
 *
 * - NORM_L1 / NORM_L2 scale each sample (row) independently by its L1 / L2 norm.
 * - NORM_STD standardizes each feature (column) as (x - mean) / std. It uses
 *   per-feature statistics that are learned once, on the first fit.
 */
class Normalizer implements Preprocessor
{
    public const NORM_L1 = 1;

    public const NORM_L2 = 2;

    public const NORM_STD = 3;

    /**
     * Selected normalization mode (one of the NORM_* constants).
     *
     * @var int
     */
    private $norm;

    /**
     * Whether fit() has already run. Once true, later fit() calls are no-ops.
     *
     * @var bool
     */
    private $fitted = false;

    /**
     * Per-feature value of StandardDeviation::population(), keyed by feature
     * index (NORM_STD only).
     *
     * @var array
     */
    private $std = [];

    /**
     * Per-feature arithmetic mean, keyed by feature index (NORM_STD only).
     *
     * @var array
     */
    private $mean = [];

    /**
     * @param int $norm One of NORM_L1, NORM_L2 or NORM_STD.
     *
     * @throws NormalizerException When $norm is not one of the supported constants.
     */
    public function __construct(int $norm = self::NORM_L2)
    {
        if (!in_array($norm, [self::NORM_L1, self::NORM_L2, self::NORM_STD], true)) {
            throw new NormalizerException('Unknown norm supplied.');
        }

        $this->norm = $norm;
    }

    /**
     * Learns per-feature mean and standard deviation for NORM_STD.
     *
     * For the row-wise norms there is nothing to learn. After the first
     * successful call, further calls return immediately, so the statistics
     * always come from the first fitted batch.
     *
     * @param array      $samples List of samples (each an array of numeric features).
     * @param array|null $targets Unused; present to satisfy the Preprocessor interface.
     */
    public function fit(array $samples, ?array $targets = null): void
    {
        if ($this->fitted) {
            return;
        }

        if ($this->norm === self::NORM_STD) {
            if ($samples === []) {
                // Nothing to learn from. Stay unfitted so a later call with real data can fit.
                return;
            }

            // Use the first sample's actual keys. range(0, count - 1) would yield
            // [0, -1] for an empty sample and ignore non-zero-based keys.
            $firstSample = reset($samples);
            foreach (array_keys($firstSample) as $i) {
                $values = array_column($samples, $i);
                $this->std[$i] = StandardDeviation::population($values);
                $this->mean[$i] = Mean::arithmetic($values);
            }
        }

        $this->fitted = true;
    }

    /**
     * Normalizes every sample in place using the configured norm.
     *
     * Calls fit() first. For NORM_STD this only learns statistics if the
     * instance has not been fitted yet.
     *
     * @param array      $samples List of samples, modified in place.
     * @param array|null $targets Unused; present to satisfy the Preprocessor interface.
     */
    public function transform(array &$samples, ?array &$targets = null): void
    {
        $methods = [
            self::NORM_L1 => 'normalizeL1',
            self::NORM_L2 => 'normalizeL2',
            self::NORM_STD => 'normalizeSTD',
        ];
        $method = $methods[$this->norm];

        $this->fit($samples);

        foreach ($samples as &$sample) {
            $this->{$method}($sample);
        }
        // Drop the lingering reference to the last element.
        unset($sample);
    }

    /**
     * Divides every feature by the sample's L1 norm (sum of absolute values).
     *
     * An all-zero sample becomes a uniform vector whose entries sum to 1.
     * An empty sample is left untouched.
     */
    private function normalizeL1(array &$sample): void
    {
        $count = count($sample);
        if ($count === 0) {
            // Nothing to scale. This also avoids 1.0 / 0 below, which throws
            // DivisionByZeroError on PHP 8.
            return;
        }

        $norm1 = 0;
        foreach ($sample as $feature) {
            $norm1 += abs($feature);
        }

        if ($norm1 == 0) {
            $sample = array_fill(0, $count, 1.0 / $count);
        } else {
            array_walk($sample, static function (&$feature) use ($norm1): void {
                $feature /= $norm1;
            });
        }
    }

    /**
     * Divides every feature by the sample's L2 (Euclidean) norm.
     *
     * An all-zero sample is replaced by a vector of ones. NOTE(review): that
     * vector has L2 norm sqrt(count), not 1; kept as-is for backward compatibility.
     */
    private function normalizeL2(array &$sample): void
    {
        $norm2 = 0;
        foreach ($sample as $feature) {
            $norm2 += $feature * $feature;
        }

        $norm2 **= .5;

        if ($norm2 == 0) {
            $sample = array_fill(0, count($sample), 1);
        } else {
            array_walk($sample, static function (&$feature) use ($norm2): void {
                $feature /= $norm2;
            });
        }
    }

    /**
     * Standardizes each feature as (x - mean) / std, using the statistics
     * learned in fit(). A feature with zero deviation is set to 0.
     */
    private function normalizeSTD(array &$sample): void
    {
        foreach (array_keys($sample) as $i) {
            if ($this->std[$i] != 0) {
                $sample[$i] = ($sample[$i] - $this->mean[$i]) / $this->std[$i];
            } else {
                // Same value for all samples.
                $sample[$i] = 0;
            }
        }
    }
}
||
132 |