Passed
Push — master ( c0463a...e1854d )
by Arkadiusz
02:54
created

OneVsRest::trainByLabel()   C

Complexity

Conditions 7
Paths 12

Size

Total Lines 43
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 43
rs 6.7272
c 0
b 0
f 0
cc 7
eloc 19
nc 12
nop 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
trait OneVsRest
{
    /**
     * Trained binary classifiers. Holds a single entry at index 0 when exactly
     * two labels exist; otherwise one entry keyed by each label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values reported by the first trained classifier, if it exposes getCostValues().
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any previously trained state is discarded before training.
     *
     * @param array $samples
     * @param array $targets
     */
    public function train(array $samples, array $targets)
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Trains the underlying binary classifier(s).
     *
     * With exactly two labels a single classifier is trained on the raw targets;
     * otherwise one classifier per label is trained on "label vs not_label"
     * targets. Existing classifiers are reused, so repeated calls keep training them.
     *
     * @param array $samples
     * @param array $targets
     * @param array $allLabels All training set labels; derived from $targets when empty.
     *                         Must be provided for each partialTrain run.
     * @return void
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = [])
    {
        // Overwrites the current value if it exists.
        $this->allLabels = !empty($allLabels) ? $allLabels : array_keys(array_count_values($targets));
        sort($this->allLabels, SORT_STRING);

        if (count($this->allLabels) == 2) {
            // Only two labels: no need to perform OvR, one classifier suffices.
            if (empty($this->classifiers)) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize them.
            foreach ($this->allLabels as $label) {
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                list($binarizedTargets, $classifierLabels) = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier can report cost values during training,
        // keep them. Only the first classifier's values are stored to avoid
        // complex average calculations. reset() yields false when no classifier
        // was trained (e.g. empty targets), so guard before method_exists().
        $firstClassifier = reset($this->classifiers);
        if (is_object($firstClassifier) && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset()
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     *
     * @return \Phpml\Estimator
     */
    protected function getClassifierCopy()
    {
        // Clone the current classifier, so that we don't mess up its variables
        // while training multiple instances of this classifier.
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Groups all targets into two groups: targets equal to the given label
     * and the others (relabelled as "not_<label>").
     *
     * $targets is not passed by reference nor contains objects so this method
     * changes will not affect the caller $targets array.
     *
     * @param array $targets
     * @param mixed $label
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets($targets, $label)
    {
        $notLabel = "not_$label";
        $labelKey = (string) $label;

        foreach ($targets as $key => $target) {
            // Compare string forms strictly: loose == would merge distinct labels
            // such as '01' and 1 (or, before PHP 8, 0 and 'a'). Casting still
            // matches int keys produced by array_count_values() to string targets.
            $targets[$key] = (string) $target === $labelKey ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }

    /**
     * Predicts the label of a single sample.
     *
     * In the multi-class case the label whose classifier reports the highest
     * probability is returned; note that numeric labels come back as ints
     * because PHP casts numeric-string array keys.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) == 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            // Array keys may have been converted to int; predictProbability()
            // takes a string, which would raise a TypeError under strict_types.
            $probs[$label] = $predictor->predictProbability($sample, (string) $label);
        }

        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     *
     * @param array $samples
     * @param array $targets
     * @param array $labels
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     *
     * @return void
     */
    abstract protected function resetBinary();

    /**
     * Each classifier that make use of OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @param array $sample
     * @param string $label
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @param array $sample
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);
}
199