Completed
Push — develop ( 76d15e...963cfe )
by Arkadiusz
03:20
created

ClassificationReport::getRecall()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 4
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
eloc 2
nc 1
nop 0
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\Metric;
6
7
/**
 * Per-label classification metrics (precision, recall, F1 score, support)
 * plus an average of each metric, computed from an array of actual labels
 * and a parallel array of predicted labels.
 */
class ClassificationReport
{
    /**
     * Precision per label: TP / (TP + FP), or 0.0 when the label was never predicted.
     *
     * @var array
     */
    private $precision = [];

    /**
     * Recall per label: TP / (TP + FN), or 0.0 when the label never occurs in the actual labels.
     *
     * @var array
     */
    private $recall = [];

    /**
     * F1 score per label (harmonic mean of precision and recall).
     *
     * @var array
     */
    private $f1score = [];

    /**
     * Number of occurrences of each label in the actual labels.
     *
     * @var array
     */
    private $support = [];

    /**
     * Averages keyed by 'precision', 'recall' and 'f1score'.
     *
     * @var array
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth labels
     * @param array $predictedLabels predicted labels; for every key of $actualLabels the
     *                               prediction is read from the same key in this array
     */
    public function __construct(array $actualLabels, array $predictedLabels)
    {
        // Every label seen on either side gets a zero-initialised slot, so a label
        // that is only ever predicted does not hit an undefined array key below.
        $labels = self::getLabelIndexedArray($actualLabels, $predictedLabels);
        $truePositive = $falsePositive = $falseNegative = $this->support = $labels;

        foreach ($actualLabels as $index => $actual) {
            $predicted = $predictedLabels[$index];
            ++$this->support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->computeMetrics($truePositive, $falsePositive, $falseNegative);
        $this->computeAverage();
    }

    /**
     * @return array precision keyed by label
     */
    public function getPrecision()
    {
        return $this->precision;
    }

    /**
     * @return array recall keyed by label
     */
    public function getRecall()
    {
        return $this->recall;
    }

    /**
     * @return array F1 score keyed by label
     */
    public function getF1score()
    {
        return $this->f1score;
    }

    /**
     * @return array number of occurrences in the actual labels, keyed by label
     */
    public function getSupport()
    {
        return $this->support;
    }

    /**
     * @return array averages keyed by 'precision', 'recall' and 'f1score'
     */
    public function getAverage()
    {
        return $this->average;
    }

    /**
     * Fills the per-label precision, recall and F1 arrays from the confusion counts.
     *
     * @param array $truePositive  true-positive count per label
     * @param array $falsePositive false-positive count per label
     * @param array $falseNegative false-negative count per label
     */
    private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative)
    {
        foreach ($truePositive as $label => $tp) {
            $this->precision[$label] = self::computeRatio($tp, $tp + $falsePositive[$label]);
            $this->recall[$label] = self::computeRatio($tp, $tp + $falseNegative[$label]);
            $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
        }
    }

    /**
     * Computes the average of each metric.
     *
     * NOTE: array_filter() drops labels whose value is 0, so each average is taken
     * over the labels with a non-zero score only. When no label has a non-zero
     * score (or there are no labels at all), the average is 0.0.
     */
    private function computeAverage()
    {
        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = array_filter($this->{$metric});
            $count = count($values);

            $this->average[$metric] = 0 === $count ? 0.0 : array_sum($values) / $count;
        }
    }

    /**
     * Divides two confusion counts, treating an empty denominator as a 0.0 score
     * instead of a division by zero.
     *
     * @param int $numerator
     * @param int $denominator
     *
     * @return float|int
     */
    private static function computeRatio(int $numerator, int $denominator)
    {
        if (0 === $denominator) {
            return 0.0;
        }

        return $numerator / $denominator;
    }

    /**
     * @param float $precision
     * @param float $recall
     *
     * @return float 2PR / (P + R), or 0.0 when both precision and recall are zero
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if (0.0 === $divider) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Builds a sorted label => 0 map covering every label from both inputs.
     *
     * @param array $actualLabels
     * @param array $predictedLabels
     *
     * @return array
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        // array_values() first: array_merge() would otherwise let a string key in
        // $predictedLabels overwrite the same key (and its label) in $actualLabels.
        $labels = array_values(array_unique(array_merge(array_values($actualLabels), array_values($predictedLabels))));
        sort($labels);

        return array_combine($labels, array_fill(0, count($labels), 0));
    }
}
149