1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Helper; |
6
|
|
|
|
7
|
|
|
trait OneVsRest
{
    /**
     * Trained binary classifiers.
     *
     * Holds a single classifier under key 0 when there are exactly two labels,
     * otherwise one classifier per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values copied from the first trained classifier, when that
     * classifier exposes a getCostValues() method.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any state left over from a previous training run is discarded first.
     *
     * @param array $samples
     * @param array $targets
     */
    public function train(array $samples, array $targets)
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Trains the underlying binary classifier(s) without resetting them first,
     * creating each classifier lazily the first time its label is seen.
     *
     * @param array $samples
     * @param array $targets
     * @param array $allLabels All training set labels; derived from $targets when empty
     *
     * @return void
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = [])
    {
        // Overwrites the current value if it exists. $allLabels must be provided for each partialTrain run.
        if (!empty($allLabels)) {
            $this->allLabels = $allLabels;
        } else {
            $this->allLabels = array_keys(array_count_values($targets));
        }
        sort($this->allLabels, SORT_STRING);

        // If there are only two targets, then there is no need to perform OvR
        if (count($this->allLabels) === 2) {
            // Init classifier if required.
            if (empty($this->classifiers)) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize them
            foreach ($this->allLabels as $label) {
                // Init classifier if required.
                if (empty($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                list($binarizedTargets, $classifierLabels) = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier is capable of giving the cost values
        // during the training, then assign it to the relevant variable.
        // Adding just the first classifier cost values to avoid complex average calculations.
        $firstClassifier = reset($this->classifiers);

        // reset() yields false when no classifier was created (no labels at all);
        // method_exists() must not be handed that boolean.
        if (is_object($firstClassifier) && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset()
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     *
     * @return \Phpml\Estimator
     */
    protected function getClassifierCopy()
    {
        // Clone the current classifier, so that
        // we don't mess up its variables while training
        // multiple instances of this classifier
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Groups all targets into two groups: Targets equal to
     * the given label and the others
     *
     * $targets is not passed by reference nor contains objects so this method
     * changes will not affect the caller $targets array.
     *
     * @param array $targets
     * @param mixed $label
     *
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets($targets, $label): array
    {
        $notLabel = "not_$label";
        $labelAsString = (string) $label;

        foreach ($targets as $key => $target) {
            // Compare canonical string forms instead of using loose ==, which
            // would merge distinct labels such as '1.0' and 1, or '1e3' and '1000'.
            $targets[$key] = (string) $target === $labelAsString ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }

    /**
     * Predicts the label of a single sample.
     *
     * With exactly two labels the single binary classifier decides; otherwise
     * the label whose classifier reports the highest probability is returned.
     * Returns null when no classifier has been trained.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        // NOTE(review): integer-like labels become int array keys, while
        // predictProbability() declares `string $label` and this file uses
        // strict_types - confirm implementers accept numeric labels.
        foreach ($this->classifiers as $label => $predictor) {
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        // Highest probability first; key() then yields the winning label.
        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     *
     * @param array $samples
     * @param array $targets
     * @param array $labels
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     *
     * @return void
     */
    abstract protected function resetBinary();

    /**
     * Each classifier that makes use of the OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @param array  $sample
     * @param string $label
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @param array $sample
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);
}
199
|
|
|
|