Passed
Push — master ( ba7114...554c86 )
by Arkadiusz
03:08
created

ClassificationReport::computeMacroAverage()   A

Complexity

Conditions 3
Paths 3

Size

Total Lines 13
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 13
rs 9.4285
c 0
b 0
f 0
cc 3
eloc 7
nc 3
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Metric;
6
7
use Phpml\Exception\InvalidArgumentException;
8
9
/**
 * Per-label and averaged precision, recall and F1 score for a pair of
 * actual / predicted label arrays.
 *
 * Every metric is computed eagerly in the constructor; the getters only
 * return the stored results.
 */
class ClassificationReport
{
    public const MICRO_AVERAGE = 1;

    public const MACRO_AVERAGE = 2;

    public const WEIGHTED_AVERAGE = 3;

    /**
     * True-positive count per label.
     *
     * @var array<int|string, int>
     */
    private $truePositive = [];

    /**
     * False-positive count per label (label was predicted, but was not the actual label).
     *
     * @var array<int|string, int>
     */
    private $falsePositive = [];

    /**
     * False-negative count per label (label was the actual label, but something else was predicted).
     *
     * @var array<int|string, int>
     */
    private $falseNegative = [];

    /**
     * Number of occurrences of each label in the actual labels.
     *
     * @var array<int|string, int>
     */
    private $support = [];

    /**
     * @var array<int|string, float>
     */
    private $precision = [];

    /**
     * @var array<int|string, float>
     */
    private $recall = [];

    /**
     * @var array<int|string, float>
     */
    private $f1score = [];

    /**
     * Averaged metrics with the keys 'precision', 'recall' and 'f1score'.
     *
     * @var array<string, float>
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth labels
     * @param array $predictedLabels predicted labels, looked up with the same keys as $actualLabels
     * @param int   $average         MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE
     *
     * @throws InvalidArgumentException when $average is not a supported averaging method
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return array<int|string, float> precision keyed by label
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return array<int|string, float> recall keyed by label
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return array<int|string, float> F1 score keyed by label
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return array<int|string, int> number of actual samples per label
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return array<string, float> averaged 'precision', 'recall' and 'f1score'
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Count TP / FP / FN and support for every label seen in either input.
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            $predicted = $predictedLabels[$index];
            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                // A miss is a false positive for the predicted label and a
                // false negative for the actual one.
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Fill the per-label precision, recall and F1 arrays.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
            $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
            $this->f1score[$label] = $this->computeF1Score($this->precision[$label], $this->recall[$label]);
        }
    }

    /**
     * Dispatch to the requested averaging strategy. The constructor has
     * already validated $average, so one of the cases always matches.
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                return;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                return;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                return;
        }
    }

    /**
     * Micro average: pool TP / FP / FN over all labels, then compute the metrics once.
     */
    private function computeMicroAverage(): void
    {
        $truePositive = array_sum($this->truePositive);
        $falsePositive = array_sum($this->falsePositive);
        $falseNegative = array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);
        $f1score = $this->computeF1Score($precision, $recall);

        $this->average = compact('precision', 'recall', 'f1score');
    }

    /**
     * Macro average: unweighted mean of each per-label metric (0.0 when there are no labels).
     */
    private function computeMacroAverage(): void
    {
        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $this->average[$metric] = array_sum($values) / count($values);
        }
    }

    /**
     * Weighted average: mean of each per-label metric weighted by that label's support.
     *
     * The result is 0.0 when there are no labels, or when there are no actual samples
     * to weight by (e.g. empty actual labels with non-empty predicted labels).
     */
    private function computeWeightedAverage(): void
    {
        $totalSupport = array_sum($this->support);

        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            // Without this guard a zero total support would be a division by zero.
            if (count($values) === 0 || $totalSupport == 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $sum = 0.0;
            foreach ($values as $label => $value) {
                $sum += $value * $this->support[$label];
            }

            $this->average[$metric] = $sum / $totalSupport;
        }
    }

    /**
     * Precision = TP / (TP + FP), or 0.0 when the label was never predicted.
     *
     * Computed in float: int / int division in PHP returns an int when the
     * result is exact (e.g. 2 / 2), which would otherwise mix ints into the metrics.
     */
    private function computePrecision(int $truePositive, int $falsePositive): float
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        return (float) $truePositive / $divider;
    }

    /**
     * Recall = TP / (TP + FN), or 0.0 when the label never occurs in the actual labels.
     *
     * Computed in float for the same reason as computePrecision().
     */
    private function computeRecall(int $truePositive, int $falseNegative): float
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        return (float) $truePositive / $divider;
    }

    /**
     * F1 = harmonic mean of precision and recall, or 0.0 when both are zero.
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if ($divider == 0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Build a zero-initialised counter array keyed by every distinct label
     * appearing in either input, sorted by label.
     *
     * @return array<int|string, int>
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
        sort($labels);
        $labels = array_combine($labels, array_fill(0, count($labels), 0));

        return $labels;
    }
}
234