Passed
Pull Request — master (#205)
by Yuji
02:32
created

ClassificationReport::computeMacroAverage()   A

Complexity

Conditions 3
Paths 3

Size

Total Lines 13
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 13
rs 9.4285
c 0
b 0
f 0
cc 3
eloc 7
nc 3
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Metric;
6
7
use Phpml\Exception\InvalidArgumentException;
8
9
/**
 * Per-label precision, recall, F1 score and support for a set of predictions,
 * plus a single averaged value of each metric.
 *
 * All values are computed once, in the constructor; the getters only expose
 * the stored results.
 */
class ClassificationReport
{
    /**
     * Average computed from the true/false positive and false negative counts summed over all labels.
     */
    public const MICRO_AVERAGE = 1;

    /**
     * Unweighted mean of the per-label metrics.
     */
    public const MACRO_AVERAGE = 2;

    /**
     * Mean of the per-label metrics, each weighted by that label's support.
     */
    public const WEIGHTED_AVERAGE = 3;

    /**
     * @var array<int|string, int> label => number of samples correctly predicted as that label
     */
    private $truePositive = [];

    /**
     * @var array<int|string, int> label => number of samples wrongly predicted as that label
     */
    private $falsePositive = [];

    /**
     * @var array<int|string, int> label => number of samples of that label predicted as something else
     */
    private $falseNegative = [];

    /**
     * @var array<int|string, int> label => number of occurrences of that label in the actual labels
     */
    private $support = [];

    /**
     * @var array<int|string, float|int> label => precision
     */
    private $precision = [];

    /**
     * @var array<int|string, float|int> label => recall
     */
    private $recall = [];

    /**
     * @var array<int|string, float> label => F1 score
     */
    private $f1score = [];

    /**
     * @var array<string, float|int> averaged metrics, keyed by 'precision', 'recall' and 'f1score'
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth labels
     * @param array $predictedLabels predicted labels, looked up by the keys of $actualLabels
     * @param int   $average         one of the *_AVERAGE class constants
     *
     * @throws InvalidArgumentException when $average is not one of the *_AVERAGE constants
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return array<int|string, float|int> label => precision
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return array<int|string, float|int> label => recall
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return array<int|string, float> label => F1 score
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return array<int|string, int> label => number of occurrences in the actual labels
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return array<string, float|int> averaged 'precision', 'recall' and 'f1score'
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Counts true positives, false positives, false negatives and support per label.
     *
     * Every label seen in either input starts at zero, so labels that only
     * occur among the predictions still get an entry in each counter.
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            // NOTE(review): assumes $predictedLabels has an entry for every key of $actualLabels.
            $predicted = $predictedLabels[$index];
            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                // A miss counts against both the predicted label and the actual one.
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Derives precision, recall and F1 score for every label from the aggregated counts.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
            $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
            $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
        }
    }

    /**
     * Dispatches to the averaging strategy selected by $average.
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                return;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                return;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                return;
        }
    }

    /**
     * Computes the metrics from the counts summed over all labels.
     */
    private function computeMicroAverage(): void
    {
        // The counters only hold integers; the casts make that explicit for
        // the int-typed helper parameters under strict_types.
        $truePositive = (int) array_sum($this->truePositive);
        $falsePositive = (int) array_sum($this->falsePositive);
        $falseNegative = (int) array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);
        $f1score = $this->computeF1Score((float) $precision, (float) $recall);

        $this->average = [
            'precision' => $precision,
            'recall' => $recall,
            'f1score' => $f1score,
        ];
    }

    /**
     * Computes the unweighted mean of each per-label metric (0.0 when there are no labels).
     */
    private function computeMacroAverage(): void
    {
        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $this->average[$metric] = array_sum($values) / count($values);
        }
    }

    /**
     * Computes the support-weighted mean of each per-label metric.
     *
     * Yields 0.0 when there are no labels or when the total support is zero
     * (labels that occur only among the predictions), instead of dividing by zero.
     */
    private function computeWeightedAverage(): void
    {
        $totalSupport = array_sum($this->support);

        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0 || $totalSupport === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $sum = 0;
            foreach ($values as $label => $value) {
                $sum += $value * $this->support[$label];
            }

            $this->average[$metric] = $sum / $totalSupport;
        }
    }

    /**
     * Returns truePositive / (truePositive + falsePositive), or 0.0 when that sum is zero.
     *
     * @return float|int an int when the division is exact, a float otherwise
     */
    private function computePrecision(int $truePositive, int $falsePositive)
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * Returns truePositive / (truePositive + falseNegative), or 0.0 when that sum is zero.
     *
     * @return float|int an int when the division is exact, a float otherwise
     */
    private function computeRecall(int $truePositive, int $falseNegative)
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * Returns 2 * precision * recall / (precision + recall), or 0.0 when both are zero.
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if ($divider === 0.0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Builds a zero-initialised counter array keyed by every distinct label
     * found in either input, in sorted order.
     *
     * @return array<int|string, int>
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
        sort($labels);

        return array_fill_keys($labels, 0);
    }
}
234