Passed
Push — master ( 331d4b...653c7c )
by Arkadiusz
02:19
created

src/Phpml/Helper/OneVsRest.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis, consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
use Phpml\Classification\Classifier;
8
9
trait OneVsRest
{
    /**
     * Trained binary classifiers.
     *
     * Holds a single classifier under key 0 when exactly two labels are known,
     * otherwise one classifier per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values reported by the first underlying classifier, if it exposes them.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any previously trained state is discarded before training starts.
     *
     * @param array $samples Training samples
     * @param array $targets Target label of each sample
     */
    public function train(array $samples, array $targets): void
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset(): void
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Trains the underlying binary classifier(s) without clearing existing ones.
     *
     * @param array $samples   Training samples
     * @param array $targets   Target label of each sample
     * @param array $allLabels Complete set of labels; when empty, the labels are
     *                         derived from $targets. It must be provided for each
     *                         partialTrain run.
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
    {
        // Overwrites the current value if it exists.
        $this->allLabels = !empty($allLabels)
            ? $allLabels
            : array_keys(array_count_values($targets));

        sort($this->allLabels, SORT_STRING);

        if (count($this->allLabels) === 2) {
            // If there are only two targets, then there is no need to perform OvR.
            if (!isset($this->classifiers[0])) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize them.
            foreach ($this->allLabels as $label) {
                // Init classifier if required.
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier is capable of giving the cost values
        // during the training, then assign it to the relevant variable.
        // Adding just the first classifier cost values to avoid complex average calculations.
        $firstClassifier = reset($this->classifiers);

        // reset() yields false when no classifier was created (no labels at all);
        // method_exists() must not be called with that value.
        if ($firstClassifier !== false && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     *
     * @return Classifier|OneVsRest
     */
    protected function getClassifierCopy()
    {
        // Clone the current classifier, so that
        // we don't mess up its variables while training
        // multiple instances of this classifier.
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Predicts the label of a single sample.
     *
     * With exactly two labels the single binary classifier decides directly.
     * Otherwise every per-label classifier is asked for a probability and the
     * label with the highest one is returned.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            // PHP stores integer-like array keys as int, while predictProbability()
            // declares a string label; cast so the call honours that declaration.
            $probs[$label] = $predictor->predictProbability($sample, (string) $label);
        }

        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     *
     * @return void
     */
    abstract protected function resetBinary(): void;

    /**
     * Each classifier that make use of OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);

    /**
     * Groups all targets into two groups: Targets equal to
     * the given label and the others
     *
     * $targets is not passed by reference nor contains objects so this method
     * changes will not affect the caller $targets array.
     *
     * @param mixed $label
     *
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets(array $targets, $label): array
    {
        $notLabel = "not_{$label}";
        $labelAsString = (string) $label;

        foreach ($targets as $key => $target) {
            // Compare as strings: a loose == would treat distinct labels such as
            // '01' and 1 as equal and merge them into the same class.
            $targets[$key] = (string) $target === $labelAsString ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }
}
178