LDA::__construct()   B
last analyzed

Complexity

Conditions 9
Paths 7

Size

Total Lines 20
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 20
rs 8.0555
c 0
b 0
f 0
cc 9
nc 7
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\DimensionReduction;
6
7
use Phpml\Exception\InvalidArgumentException;
8
use Phpml\Exception\InvalidOperationException;
9
use Phpml\Math\Matrix;
10
11
class LDA extends EigenTransformerBase
{
    /**
     * @var bool True once fit() has completed; transform() refuses to run before that.
     */
    public $fit = false;

    /**
     * @var array Unique class labels seen by the last fit(), in order of first appearance.
     */
    public $labels = [];

    /**
     * @var array Per-class column means, keyed by the label's position in $labels.
     */
    public $means = [];

    /**
     * @var array Number of training samples per class, keyed like $means.
     */
    public $counts = [];

    /**
     * @var float[] Column means over the whole training set.
     */
    public $overallMean = [];

    /**
     * Linear Discriminant Analysis (LDA) is used to reduce the dimensionality
     * of the data. Unlike Principal Component Analysis (PCA), it is a supervised
     * technique that requires the class labels in order to fit the data to a
     * lower dimensional space. <br><br>
     * The algorithm can be initialized by specifying
     * either the totalVariance (a value between 0.1 and 0.99)
     * or numFeatures (number of features in the dataset) to be preserved.
     * Exactly one of the two must be given.
     *
     * @param float|null $totalVariance Total explained variance to be preserved
     * @param int|null   $numFeatures   Number of features to be preserved
     *
     * @throws InvalidArgumentException when a value is out of range, or when
     *                                  both or neither of the options are given
     */
    public function __construct(?float $totalVariance = null, ?int $numFeatures = null)
    {
        if ($totalVariance !== null && ($totalVariance < 0.1 || $totalVariance > 0.99)) {
            throw new InvalidArgumentException('Total variance can be a value between 0.1 and 0.99');
        }

        if ($numFeatures !== null && $numFeatures <= 0) {
            throw new InvalidArgumentException('Number of features to be preserved should be greater than 0');
        }

        // Exactly one option must be set: both null or both non-null is rejected.
        if (($totalVariance !== null) === ($numFeatures !== null)) {
            throw new InvalidArgumentException('Either totalVariance or numFeatures should be specified in order to run the algorithm');
        }

        if ($numFeatures !== null) {
            $this->numFeatures = $numFeatures;
        }

        if ($totalVariance !== null) {
            $this->totalVariance = $totalVariance;
        }
    }

    /**
     * Trains the algorithm to transform the given data to a lower dimensional space.
     *
     * Builds the within-class scatter matrix S_W and the between-class scatter
     * matrix S_B, eigen-decomposes S_W^-1 * S_B, and returns the training data
     * passed through reduce().
     *
     * @param array $data    Samples, one row of numeric features per sample
     * @param array $classes Class label of each sample, keyed like $data
     *
     * @return array The reduced representation of $data
     *
     * @throws InvalidArgumentException when $data is empty or the number of
     *                                  samples and labels differ
     */
    public function fit(array $data, array $classes): array
    {
        if ($data === []) {
            throw new InvalidArgumentException('Data to fit LDA must not be empty');
        }

        if (count($data) !== count($classes)) {
            throw new InvalidArgumentException('Number of samples and number of class labels must match');
        }

        $this->labels = $this->getLabels($classes);
        $this->means = $this->calculateMeans($data, $classes);

        $sW = $this->calculateClassVar($data, $classes);
        $sB = $this->calculateClassCov();

        $scatter = $sW->inverse()->multiply($sB);
        $this->eigenDecomposition($scatter->toArray());

        $this->fit = true;

        return $this->reduce($data);
    }

    /**
     * Transforms the given sample to a lower dimensional vector by using
     * the eigenVectors obtained in the last run of <code>fit</code>.
     *
     * A single sample (flat array) is wrapped into a one-row dataset first.
     *
     * @throws InvalidOperationException when fit() has not been called yet
     */
    public function transform(array $sample): array
    {
        if (!$this->fit) {
            throw new InvalidOperationException('LDA has not been fitted with respect to original dataset, please run LDA::fit() first');
        }

        // `?? null` avoids an undefined-offset warning for an empty sample.
        if (!is_array($sample[0] ?? null)) {
            $sample = [$sample];
        }

        return $this->reduce($sample);
    }

    /**
     * Returns unique labels in the dataset, in order of first appearance.
     *
     * Note: because the labels are array keys, PHP converts numeric-string
     * labels (e.g. '1') to integers.
     */
    protected function getLabels(array $classes): array
    {
        $counts = array_count_values($classes);

        return array_keys($counts);
    }

    /**
     * Calculates the mean of each column for each class and returns an
     * n by m matrix, where n is the number of labels and m the number of
     * columns. Rows are keyed by the label's position in $this->labels.
     *
     * Side effects: stores the per-class sample counts in $this->counts and
     * the column means of the whole dataset in $this->overallMean.
     */
    protected function calculateMeans(array $data, array $classes): array
    {
        $means = [];
        $counts = [];
        $overallMean = array_fill(0, count($data[0]), 0.0);

        // Key lookup (not a strict array_search) so numeric-string labels match
        // the int keys produced by array_count_values() in getLabels().
        $labelIndex = array_flip($this->labels);

        foreach ($data as $index => $row) {
            $label = $labelIndex[$classes[$index]];

            foreach ($row as $col => $val) {
                if (!isset($means[$label][$col])) {
                    $means[$label][$col] = 0.0;
                }

                $means[$label][$col] += $val;
                $overallMean[$col] += $val;
            }

            if (!isset($counts[$label])) {
                $counts[$label] = 0;
            }

            ++$counts[$label];
        }

        foreach ($means as $label => $row) {
            foreach ($row as $col => $sum) {
                $means[$label][$col] = $sum / $counts[$label];
            }
        }

        // Overall mean of the dataset for each column.
        $numElements = array_sum($counts);
        $this->overallMean = array_map(static function ($sum) use ($numElements) {
            return $sum / $numElements;
        }, $overallMean);
        $this->counts = $counts;

        return $means;
    }

    /**
     * Returns the within-class scatter matrix
     * S_W = sum over samples x of (x - mean_of_x's_class)^T (x - mean_of_x's_class),
     * an m by m matrix where m is the number of columns.
     *
     * Requires $this->labels and $this->means to have been populated.
     */
    protected function calculateClassVar(array $data, array $classes): Matrix
    {
        // m x m zero matrix (m = number of columns) used as the accumulator.
        $numColumns = count($data[0]);
        $s = array_fill(0, $numColumns, array_fill(0, $numColumns, 0));
        $sW = new Matrix($s, false);

        // Same key-based lookup as calculateMeans() so means are found for numeric-string labels.
        $labelIndex = array_flip($this->labels);

        foreach ($data as $index => $row) {
            $classMeans = $this->means[$labelIndex[$classes[$index]]];

            $sW = $sW->add($this->calculateVar($row, $classMeans));
        }

        return $sW;
    }

    /**
     * Returns the between-class scatter matrix
     * S_B = sum over classes c of N_c * (mean_c - overallMean)^T (mean_c - overallMean),
     * an m by m matrix where m is the number of columns and N_c the number of
     * samples in class c.
     *
     * Requires $this->means, $this->counts and $this->overallMean to have been populated.
     */
    protected function calculateClassCov(): Matrix
    {
        // m x m zero matrix (m = number of columns) used as the accumulator.
        $numColumns = count($this->overallMean);
        $s = array_fill(0, $numColumns, array_fill(0, $numColumns, 0));
        $sB = new Matrix($s, false);

        foreach ($this->means as $index => $classMeans) {
            $row = $this->calculateVar($classMeans, $this->overallMean);
            $sB = $sB->add($row->multiplyByScalar($this->counts[$index]));
        }

        return $sB;
    }

    /**
     * Returns the result of the calculation (x - m)T.(x - m), i.e. the outer
     * product of the difference vector with itself.
     *
     * NOTE(review): relies on Matrix treating a flat array as a 1 x m row
     * vector, so that the product is m x m; confirm against Phpml\Math\Matrix.
     */
    protected function calculateVar(array $row, array $means): Matrix
    {
        $x = new Matrix($row, false);
        $m = new Matrix($means, false);
        $diff = $x->subtract($m);

        return $diff->transpose()->multiply($diff);
    }
}
224