Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Helper/OneVsRest.php (2 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
use Phpml\Classification\Classifier;
8
9
trait OneVsRest
{
    /**
     * Trained classifier copies: keyed by label in OvR mode, or a single
     * entry at key 0 when exactly two labels are present.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels, sorted with SORT_STRING.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values taken from the first classifier copy, when that copy
     * exposes a getCostValues() method.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any previous training state is discarded first.
     */
    public function train(array $samples, array $targets): void
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset(): void
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Trains (or keeps training) the classifier copies for every label.
     *
     * With exactly two labels a single binary copy is trained on the raw
     * targets; otherwise one copy per label is trained on "label vs not_label"
     * targets. Existing copies are reused, which is what allows incremental
     * (partial) training.
     *
     * @param array $allLabels Full label set. When non-empty it overrides the labels
     *                         found in $targets; it must be provided for each
     *                         partialTrain run.
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
    {
        // Overwrites the current value if it exists.
        $this->allLabels = $allLabels !== []
            ? $allLabels
            : array_keys(array_count_values($targets));

        sort($this->allLabels, SORT_STRING);

        if (count($this->allLabels) === 2) {
            // If there are only two labels, then there is no need to perform OvR.
            if (!isset($this->classifiers[0])) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize them.
            foreach ($this->allLabels as $label) {
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                // Short array destructuring assigns both locals from the returned pair.
                [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier is capable of giving the cost values
        // during the training, then assign it to the relevant variable.
        // Only the first classifier's cost values are kept to avoid complex
        // average calculations.
        // reset() returns false when no copy was trained (e.g. empty $targets),
        // and method_exists() rejects a bool on PHP 8, hence the is_object guard.
        $firstClassifier = reset($this->classifiers);
        if (is_object($firstClassifier) && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     *
     * The current object is cloned so that training the copies does not mess
     * up this instance's own variables.
     *
     * @return static
     */
    protected function getClassifierCopy()
    {
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Predicts the label of a single sample.
     *
     * With two labels the single binary copy decides. Otherwise the label whose
     * copy reports the highest probability is returned.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            // NOTE(review): array keys turn numeric labels into ints; under
            // strict_types this call fails if an implementation declares
            // `string $label` — confirm implementations accept int labels.
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        // Highest probability first; its key is the predicted label.
        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     */
    abstract protected function resetBinary(): void;

    /**
     * Each classifier that make use of OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);

    /**
     * Groups all targets into two groups: Targets equal to
     * the given label and the others
     *
     * $targets is not passed by reference nor contains objects so this method
     * changes will not affect the caller $targets array.
     *
     * @param mixed $label
     *
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets(array $targets, $label): array
    {
        // Braced interpolation: the "${var}" form is deprecated as of PHP 8.2.
        $notLabel = "not_{$label}";

        foreach ($targets as $key => $target) {
            // Loose comparison on purpose: labels derived via array_keys() turn
            // numeric-string targets into ints, and '1' must still match 1.
            $targets[$key] = $target == $label ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }
}
177