Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Helper/OneVsRest.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
use Phpml\Classification\Classifier;
8
9
trait OneVsRest
{
    /**
     * Trained binary classifiers.
     *
     * Holds a single entry at key 0 when there are exactly two labels,
     * otherwise one entry per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values taken from the first underlying classifier, when it exposes
     * a getCostValues() method.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any state left over from a previous training run is discarded first.
     */
    public function train(array $samples, array $targets): void
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset(): void
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Trains one binary classifier (two labels) or one classifier per label
     * (any other number of labels) on the given data.
     *
     * @param array $allLabels Complete label set; when empty, the labels are
     *                         derived from $targets. Must be provided for each
     *                         partialTrain run.
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
    {
        // Overwrites the current value if it exists.
        // Note: array_keys() turns numeric-string labels into integers.
        $this->allLabels = $allLabels !== []
            ? $allLabels
            : array_keys(array_count_values($targets));

        sort($this->allLabels, SORT_STRING);

        if (count($this->allLabels) === 2) {
            // If there are only two targets, then there is no need to perform OvR.
            // Init classifier if required.
            if ($this->classifiers === []) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize them.
            foreach ($this->allLabels as $label) {
                // Init classifier if required.
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                // Short array destructuring defines both variables used below.
                [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier is capable of giving the cost values
        // during the training, then assign it to the relevant variable.
        // Adding just the first classifier cost values to avoid complex average calculations.
        // reset() yields false when no classifier was created (e.g. empty targets);
        // method_exists() must not be called with that value.
        $firstClassifier = reset($this->classifiers);
        if (is_object($firstClassifier) && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     *
     * @return Classifier|OneVsRest
     */
    protected function getClassifierCopy()
    {
        // Clone the current classifier, so that
        // we don't mess up its variables while training
        // multiple instances of this classifier.
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Predicts the label of a single sample.
     *
     * With exactly two labels the single binary classifier decides directly;
     * otherwise the label whose classifier reports the highest probability wins.
     *
     * @return mixed The predicted label, or null when no classifier is available
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        // Highest probability first; key() then yields the winning label.
        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets).
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     */
    abstract protected function resetBinary(): void;

    /**
     * Each classifier that makes use of the OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample().
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);

    /**
     * Groups all targets into two groups: Targets equal to
     * the given label and the others.
     *
     * $targets is not passed by reference nor contains objects so this method
     * changes will not affect the caller $targets array.
     *
     * @param mixed $label
     *
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets(array $targets, $label): array
    {
        $notLabel = "not_{$label}";

        // Labels obtained through array_keys() may be integers while the
        // matching targets are numeric strings. Comparing the string forms
        // still matches '1' with 1, but unlike loose == it does not merge
        // distinct labels such as '1e1' and 10.
        $labelAsString = (string) $label;

        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target === $labelAsString ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }
}
177