Passed
Push — master ( 331d4b...653c7c )
by Arkadiusz
02:19
created

src/Phpml/Helper/OneVsRest.php (3 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
use Phpml\Classification\Classifier;
8
9
/**
 * One-vs-Rest (OvR) helper that turns a binary classifier into a multi-class one.
 *
 * Using classes implement the binary hooks declared at the bottom of this trait.
 * The trait trains a single copy of the classifier when there are exactly two labels.
 * Otherwise it trains one "label vs. everything else" copy per label and predicts the
 * label whose copy reports the highest probability.
 */
trait OneVsRest
{
    /**
     * Trained classifier copies: index 0 in the two-label case, otherwise keyed by label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels (sorted with SORT_STRING).
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values taken from the first classifier copy, when it provides getCostValues().
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any state left over from a previous training run is discarded first.
     */
    public function train(array $samples, array $targets): void
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Trains (or keeps training) the underlying binary classifier copies.
     *
     * @param array $samples   Training samples, passed unchanged to each trainBinary() call.
     * @param array $targets   One target label per sample.
     * @param array $allLabels Full label set. When empty, it is derived from $targets.
     *                         It must be provided for each partialTrain run.
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
    {
        // Overwrites the current value if it exists. array_keys(array_count_values())
        // yields the distinct targets; numeric-string targets come back as int keys.
        $this->allLabels = $allLabels !== []
            ? $allLabels
            : array_keys(array_count_values($targets));
        sort($this->allLabels, SORT_STRING);

        if (count($this->allLabels) === 2) {
            // With only two labels there is no need to perform OvR: one binary classifier suffices.
            if (empty($this->classifiers)) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize it. Existing copies
            // are reused so that repeated (partial) training keeps accumulating.
            foreach ($this->allLabels as $label) {
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                // The destructuring assignment defines both locals; analysers that
                // report them as undefined do not understand the `[$a, $b] =` syntax.
                [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier can report cost values during training, keep them.
        // Only the first classifier's values are used, to avoid complex averaging.
        // reset() returns false when no classifier was created (e.g. empty $targets);
        // method_exists() rejects a bool argument, so that case is skipped.
        $firstClassifier = reset($this->classifiers);
        if ($firstClassifier !== false && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset(): void
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     *
     * @return Classifier|OneVsRest
     */
    protected function getClassifierCopy()
    {
        // Clone the current classifier, so that we don't mess up its variables
        // while training multiple instances of this classifier.
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Groups all targets into two groups: targets equal to the given label and the others.
     *
     * $targets is not passed by reference and contains no objects, so the changes made
     * here do not affect the caller's $targets array.
     *
     * @param mixed $label
     *
     * @return array Binarized targets and target's labels: [$targets, [$label, "not_$label"]]
     */
    private function binarizeTargets(array $targets, $label): array
    {
        $notLabel = "not_$label";
        foreach ($targets as $key => $target) {
            // Loose comparison on purpose: numeric-string targets must match the int
            // labels produced by array_keys(array_count_values()).
            $targets[$key] = $target == $label ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }

    /**
     * Predicts the label of a single sample.
     *
     * With two labels, the single binary classifier decides. Otherwise, the label whose
     * classifier reports the highest probability wins. Returns null when no classifier
     * has been trained in the multi-class path.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];
        foreach ($this->classifiers as $label => $predictor) {
            // NOTE(review): numeric labels arrive here as int array keys, while the abstract
            // declaration says `string $label`; presumably implementors accept untyped
            // labels — confirm before tightening either side.
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets).
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     *
     * @return void
     */
    abstract protected function resetBinary(): void;

    /**
     * Each classifier that makes use of the OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample().
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);
}
178