ClassificationReport   A
last analyzed

Complexity

Total Complexity 31

Size/Duplication

Total Lines 222
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 31
eloc 87
dl 0
loc 222
rs 9.92
c 0
b 0
f 0

16 Methods

Rating   Name   Duplication   Size   Complexity  
A getRecall() 0 3 1
A getSupport() 0 3 1
A computeMacroAverage() 0 11 3
A computePrecision() 0 8 2
A getF1score() 0 3 1
A getLabelIndexedArray() 0 6 1
A computeRecall() 0 8 2
A computeWeightedAverage() 0 16 4
A __construct() 0 10 2
A aggregateClassificationResults() 0 20 3
A computeF1Score() 0 8 2
A computeAverage() 0 15 4
A computeMetrics() 0 6 2
A computeMicroAverage() 0 11 1
A getAverage() 0 3 1
A getPrecision() 0 3 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Metric;
6
7
use Phpml\Exception\InvalidArgumentException;
8
9
/**
 * Per-label and averaged precision, recall and F1 score for a set of predictions.
 *
 * All metrics are computed eagerly in the constructor; the getters only return the
 * stored results. Per-label arrays are keyed by label, ordered by sort() over the
 * union of actual and predicted labels, and hold float metric values.
 */
class ClassificationReport
{
    public const MICRO_AVERAGE = 1;

    public const MACRO_AVERAGE = 2;

    public const WEIGHTED_AVERAGE = 3;

    /**
     * Number of correct predictions per label.
     *
     * @var array<int|string, int>
     */
    private $truePositive = [];

    /**
     * Number of times each label was predicted while the actual label differed.
     *
     * @var array<int|string, int>
     */
    private $falsePositive = [];

    /**
     * Number of times each label was the actual label but a different one was predicted.
     *
     * @var array<int|string, int>
     */
    private $falseNegative = [];

    /**
     * Number of occurrences of each label among the actual labels.
     *
     * @var array<int|string, int>
     */
    private $support = [];

    /**
     * @var array<int|string, float>
     */
    private $precision = [];

    /**
     * @var array<int|string, float>
     */
    private $recall = [];

    /**
     * @var array<int|string, float>
     */
    private $f1score = [];

    /**
     * Averaged metrics keyed by 'precision', 'recall' and 'f1score'.
     *
     * @var array<string, float>
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth labels
     * @param array $predictedLabels predicted labels; paired with $actualLabels by array key
     * @param int   $average         one of MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE
     *
     * @throws InvalidArgumentException when $average is not a supported averaging method
     *                                  or the two label arrays differ in size
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        // Each actual label is looked up against a prediction; without this guard a
        // short $predictedLabels array yields undefined-index reads and a bogus '' label.
        if (count($actualLabels) !== count($predictedLabels)) {
            throw new InvalidArgumentException('Size of given arrays does not match');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return array<int|string, float> precision per label
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return array<int|string, float> recall per label
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return array<int|string, float> F1 score per label
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return array<int|string, int> number of actual occurrences per label
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return array<string, float> averaged 'precision', 'recall' and 'f1score'
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Counts true positives, false positives, false negatives and support per label.
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        // Start every counter at zero for every label seen on either side, so labels
        // that are only ever predicted (or only ever actual) still get an entry.
        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            $predicted = $predictedLabels[$index];
            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                // A miss counts against both sides: the wrongly predicted label
                // gains a false positive, the missed actual label a false negative.
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Fills the per-label precision, recall and F1 arrays from the aggregated counts.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $precision = $this->computePrecision($tp, $this->falsePositive[$label]);
            $recall = $this->computeRecall($tp, $this->falseNegative[$label]);

            $this->precision[$label] = $precision;
            $this->recall[$label] = $recall;
            $this->f1score[$label] = $this->computeF1Score($precision, $recall);
        }
    }

    /**
     * Dispatches to the averaging strategy selected in the constructor.
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                return;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                return;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                return;
        }
    }

    /**
     * Micro average: metrics computed once from counts summed over all labels.
     */
    private function computeMicroAverage(): void
    {
        $truePositive = (int) array_sum($this->truePositive);
        $falsePositive = (int) array_sum($this->falsePositive);
        $falseNegative = (int) array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);
        $f1score = $this->computeF1Score($precision, $recall);

        $this->average = compact('precision', 'recall', 'f1score');
    }

    /**
     * Macro average: unweighted mean of the per-label metrics (0.0 when there are no labels).
     */
    private function computeMacroAverage(): void
    {
        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $this->average[$metric] = array_sum($values) / count($values);
        }
    }

    /**
     * Weighted average: per-label metrics weighted by each label's support
     * (0.0 when there are no labels).
     */
    private function computeWeightedAverage(): void
    {
        // Equals the number of actual labels; non-zero whenever any label exists,
        // because the constructor enforces equally sized label arrays.
        $totalSupport = array_sum($this->support);

        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $sum = 0.0;
            foreach ($values as $label => $value) {
                $sum += $value * $this->support[$label];
            }

            $this->average[$metric] = $sum / $totalSupport;
        }
    }

    /**
     * Precision = TP / (TP + FP), defined as 0.0 when nothing was predicted.
     *
     * The float return type makes exact divisions (e.g. 2 / 2) come back as 1.0
     * instead of int 1, so every metric value has the same type.
     */
    private function computePrecision(int $truePositive, int $falsePositive): float
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * Recall = TP / (TP + FN), defined as 0.0 when the label never occurs.
     *
     * Declared float for the same type-consistency reason as computePrecision().
     */
    private function computeRecall(int $truePositive, int $falseNegative): float
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * F1 = harmonic mean of precision and recall, 0.0 when both are 0.0.
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        // Both inputs are non-negative, so the sum is exactly 0.0 only when both are.
        if ($divider === 0.0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Builds a zero-initialised array keyed by every distinct label from both inputs,
     * in sort() order.
     *
     * @return array<int|string, int>
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
        sort($labels);

        return (array) array_combine($labels, array_fill(0, count($labels), 0));
    }
}
233