Passed
Pull Request — master (#78)
by
unknown
04:36
created

OneVsRest::reset()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 5
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 4
nc 1
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
/**
 * One-vs-Rest (OvR) training and prediction on top of a binary classifier.
 *
 * The using class supplies the binary primitives (trainBinary(),
 * predictProbability() and predictSampleBinary()). This trait clones the
 * using object once per label, trains every clone on a "label vs. not label"
 * view of the targets and, at prediction time, returns the label whose
 * classifier reports the highest probability. When exactly two labels are
 * known, a single binary classifier stored under key 0 is used instead.
 */
trait OneVsRest
{
    /**
     * Trained sub-classifiers: a single entry under key 0 when exactly two
     * labels are known, otherwise one entry per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values reported by the first sub-classifier, when it exposes a
     * getCostValues() method.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any state left over from a previous training run is discarded first, so
     * that classifiers trained for labels that no longer exist cannot take
     * part in later predictions.
     *
     * @param array $samples
     * @param array $targets
     */
    public function train(array $samples, array $targets)
    {
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Trains (or keeps training) the sub-classifiers without clearing them.
     *
     * @param array $samples
     * @param array $targets
     * @param array $allLabels All training set labels
     * @return void
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = [])
    {
        // Overwrites the current value if it exist. $allLabels must be provided for each partialTrain run.
        $this->allLabels = empty($allLabels)
            ? array_keys(array_count_values($targets))
            : $allLabels;
        sort($this->allLabels, SORT_STRING);

        // If there are only two targets, then there is no need to perform OvR
        if (count($this->allLabels) === 2) {
            // Init classifier if required. The exact key is checked because
            // the list may hold label-keyed classifiers without an entry 0.
            if (!isset($this->classifiers[0])) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize them
            foreach ($this->allLabels as $label) {
                // Init classifier if required.
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                list($binarizedTargets, $classifierLabels) = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier is capable of giving the cost values
        // during the training, then assign it to the relevant variable.
        // Classifiers are keyed by label in the multi-class case, so index 0
        // is not guaranteed to exist: take whichever classifier comes first.
        // This is the global reset() array function, not $this->reset().
        $firstClassifier = reset($this->classifiers);
        if (is_object($firstClassifier) && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Resets the internal vars used by OneVsRest instances.
     */
    public function reset()
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];
    }

    /**
     * Returns a clone of the current object after cleaning up OneVsRest stuff.
     *
     * @return static
     */
    protected function getClassifierCopy()
    {
        // Clone the current classifier, so that
        // we don't mess up its variables while training
        // multiple instances of this classifier
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Groups all targets into two groups: Targets equal to
     * the given label and the others.
     *
     * $targets is not passed by reference nor contains objects so this method
     * changes will not affect the caller $targets array.
     *
     * @param array $targets
     * @param mixed $label
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets(array $targets, $label)
    {
        $notLabel = "not_$label";
        // Labels derived from array keys may have been turned into integers,
        // so compare string representations. A loose == would also treat
        // distinct numeric strings such as '1' and '01' as the same label.
        $labelAsString = (string) $label;
        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target === $labelAsString ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }

    /**
     * Predicts the label of a single sample.
     *
     * With exactly two known labels the answer of the single binary
     * classifier is returned; otherwise the label whose classifier reports
     * the highest probability wins.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            // Array keys can be integers while predictProbability() declares
            // a string label; cast so the call is valid under strict_types.
            $probs[$label] = $predictor->predictProbability($sample, (string) $label);
        }

        // Highest probability first; key() then yields the winning label.
        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     *
     * @param array $samples
     * @param array $targets
     * @param array $labels
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * Each classifier that make use of OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @param array $sample
     * @param string $label
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @param array $sample
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);
}
184