Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

LDA::fit()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 15
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 15
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 9
nc 1
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\DimensionReduction;
6
7
use Exception;
8
use Phpml\Math\Matrix;
9
10
class LDA extends EigenTransformerBase
{
    /**
     * Whether fit() has completed; transform() refuses to run until it has.
     *
     * @var bool
     */
    public $fit = false;

    /**
     * Unique class labels of the training set, as returned by getLabels().
     *
     * @var array
     */
    public $labels = [];

    /**
     * Per-class column means, keyed by the label's position in $labels.
     *
     * @var array
     */
    public $means = [];

    /**
     * Number of training samples per class, keyed by the label's position in $labels.
     *
     * @var array
     */
    public $counts = [];

    /**
     * Column means computed over the whole training set.
     *
     * @var float[]
     */
    public $overallMean = [];

    /**
     * Linear Discriminant Analysis (LDA) is used to reduce the dimensionality
     * of the data. Unlike Principal Component Analysis (PCA), it is a supervised
     * technique that requires the class labels in order to fit the data to a
     * lower dimensional space. <br><br>
     * The algorithm can be initialized by specifying either totalVariance
     * (a value between 0.1 and 0.99) or numFeatures (the number of features
     * to be preserved), but not both. A value passed as null leaves the
     * corresponding inherited property unchanged.
     *
     * @param float|null $totalVariance Total explained variance to be preserved
     * @param int|null   $numFeatures   Number of features to be preserved
     *
     * @throws \Exception When a value is out of range or both values are given
     */
    public function __construct(?float $totalVariance = null, ?int $numFeatures = null)
    {
        if ($totalVariance !== null && ($totalVariance < 0.1 || $totalVariance > 0.99)) {
            throw new Exception('Total variance can be a value between 0.1 and 0.99');
        }

        if ($numFeatures !== null && $numFeatures <= 0) {
            throw new Exception('Number of features to be preserved should be greater than 0');
        }

        if ($totalVariance !== null && $numFeatures !== null) {
            throw new Exception('Either totalVariance or numFeatures should be specified in order to run the algorithm');
        }

        if ($numFeatures !== null) {
            $this->numFeatures = $numFeatures;
        }

        if ($totalVariance !== null) {
            $this->totalVariance = $totalVariance;
        }
    }

    /**
     * Trains the algorithm to transform the given data to a lower dimensional space.
     *
     * Builds the within-class scatter matrix S_W and the between-class scatter
     * matrix S_B, runs eigenDecomposition() on S_W^-1 * S_B and returns the
     * result of reduce() applied to $data.
     *
     * @param array $data    Training samples, one row (array of numbers) per sample
     * @param array $classes Class label of each sample, aligned with $data by key
     *
     * @throws \Exception When $data is empty or $data and $classes differ in size
     */
    public function fit(array $data, array $classes): array
    {
        if ($data === []) {
            throw new Exception('LDA requires at least one sample to fit');
        }

        // A size mismatch would otherwise make $classes[$index] null and
        // silently assign rows to the wrong class.
        if (count($data) !== count($classes)) {
            throw new Exception('Size of given arrays does not match');
        }

        $this->labels = $this->getLabels($classes);
        $this->means = $this->calculateMeans($data, $classes);

        $sW = $this->calculateClassVar($data, $classes);
        $sB = $this->calculateClassCov();

        $S = $sW->inverse()->multiply($sB);
        $this->eigenDecomposition($S->toArray());

        $this->fit = true;

        return $this->reduce($data);
    }

    /**
     * Transforms the given sample to a lower dimensional vector by using
     * the eigenVectors obtained in the last run of <code>fit</code>.
     *
     * @param array $sample Either a single sample (flat array) or a list of samples
     *
     * @throws \Exception When fit() has not been called yet
     */
    public function transform(array $sample): array
    {
        if (!$this->fit) {
            throw new Exception('LDA has not been fitted with respect to original dataset, please run LDA::fit() first');
        }

        // A flat array is a single sample; wrap it so reduce() always gets a list.
        if (!is_array($sample[0])) {
            $sample = [$sample];
        }

        return $this->reduce($sample);
    }

    /**
     * Returns unique labels in the dataset, in order of first appearance.
     *
     * Note: array_count_values() converts numeric-string labels to int keys.
     */
    protected function getLabels(array $classes): array
    {
        $counts = array_count_values($classes);

        return array_keys($counts);
    }

    /**
     * Calculates mean of each column for each class and returns
     * an n by m matrix where n is number of labels and m is number of columns.
     *
     * As a side effect, stores the per-class sample counts in $this->counts
     * and the column means of the whole dataset in $this->overallMean.
     */
    protected function calculateMeans(array $data, array $classes): array
    {
        $labelIndex = $this->getLabelIndexMap();
        $means = [];
        $counts = [];
        $overallMean = array_fill(0, count($data[0]), 0.0);

        foreach ($data as $index => $row) {
            $label = $labelIndex[$classes[$index]];

            foreach ($row as $col => $val) {
                $means[$label][$col] = ($means[$label][$col] ?? 0.0) + $val;
                $overallMean[$col] += $val;
            }

            $counts[$label] = ($counts[$label] ?? 0) + 1;
        }

        // Turn the per-class column sums into per-class column means.
        foreach ($means as $label => $sums) {
            foreach ($sums as $col => $sum) {
                $means[$label][$col] = $sum / $counts[$label];
            }
        }

        // Calculate overall mean of the dataset for each column
        $numElements = array_sum($counts);
        $this->overallMean = array_map(static function ($el) use ($numElements) {
            return $el / $numElements;
        }, $overallMean);
        $this->counts = $counts;

        return $means;
    }

    /**
     * Returns the within-class scatter matrix S_W: the sum over all samples
     * of (x - m_c)T.(x - m_c), where m_c is the mean of the sample's class.
     * The result is an m by m matrix, m being the number of columns.
     */
    protected function calculateClassVar(array $data, array $classes): Matrix
    {
        $labelIndex = $this->getLabelIndexMap();

        // Start from an m by m zero matrix (m = number of columns)
        $numColumns = count($data[0]);
        $sW = new Matrix(array_fill(0, $numColumns, array_fill(0, $numColumns, 0)), false);

        foreach ($data as $index => $row) {
            $classMeans = $this->means[$labelIndex[$classes[$index]]];

            $sW = $sW->add($this->calculateVar($row, $classMeans));
        }

        return $sW;
    }

    /**
     * Returns the between-class scatter matrix S_B: the sum over all classes
     * of N_c * (m_c - m)T.(m_c - m), where m_c and N_c are the class mean and
     * sample count and m is the overall mean. The result is an m by m matrix,
     * m being the number of columns.
     */
    protected function calculateClassCov(): Matrix
    {
        // Start from an m by m zero matrix (m = number of columns)
        $numColumns = count($this->overallMean);
        $sB = new Matrix(array_fill(0, $numColumns, array_fill(0, $numColumns, 0)), false);

        foreach ($this->means as $index => $classMeans) {
            $scatter = $this->calculateVar($classMeans, $this->overallMean);
            $sB = $sB->add($scatter->multiplyByScalar($this->counts[$index]));
        }

        return $sB;
    }

    /**
     * Returns the result of the calculation (x - m)T.(x - m)
     * for a row vector x and mean vector m.
     */
    protected function calculateVar(array $row, array $means): Matrix
    {
        $x = new Matrix($row, false);
        $m = new Matrix($means, false);
        // NOTE(review): Scrutinizer flags this call ("$m is of type Matrix,
        // but the function expects self"); that issue is marked as ignored
        // in the CI report.
        $diff = $x->subtract($m);

        return $diff->transpose()->multiply($diff);
    }

    /**
     * Maps each label in $this->labels to its position in that list.
     *
     * A key lookup applies the same int/string key coercion that
     * array_count_values() applied when building the labels, so '1' and 1
     * resolve to the same class, without loose-comparison false matches.
     */
    private function getLabelIndexMap(): array
    {
        return array_flip($this->labels);
    }
}
223