1
|
|
|
<?php |
2
|
|
|
|
3
|
|
|
declare(strict_types=1); |
4
|
|
|
|
5
|
|
|
namespace Phpml\Helper; |
6
|
|
|
|
7
|
|
|
use Phpml\Classification\Classifier; |
8
|
|
|
|
9
|
|
|
trait OneVsRest
{
    /**
     * Underlying binary classifiers created by trainByLabel().
     *
     * With exactly two labels a single classifier is stored at index 0; otherwise
     * there is one "label vs. the rest" classifier per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels, sorted as strings.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values reported by the first underlying classifier, when it exposes getCostValues().
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style.
     *
     * Any previously trained state is discarded before training.
     */
    public function train(array $samples, array $targets): void
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset(): void
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Trains a single binary classifier (two labels) or one classifier per label (OvR).
     *
     * @param array $samples   training samples, handed unchanged to each trainBinary() call
     * @param array $targets   one target label per sample
     * @param array $allLabels complete label set; derived from $targets when empty.
     *                         It must be provided for each partialTrain run.
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
    {
        // Overwrites the current value if it exists.
        $this->allLabels = count($allLabels) === 0 ? array_keys(array_count_values($targets)) : $allLabels;
        sort($this->allLabels, SORT_STRING);

        if (count($this->allLabels) === 2) {
            // With only two labels there is no need to perform OvR: one binary
            // classifier, stored at index 0, separates them directly.
            if (!isset($this->classifiers[0])) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate "label vs. the rest" classifier for each label and memorize it.
            foreach ($this->allLabels as $label) {
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier can report cost values gathered during
        // training, expose them. Only the first classifier's values are kept to
        // avoid complex average calculations.
        // reset() returns false for an empty array (e.g. no targets at all). Guard
        // against that case, because method_exists() rejects a bool under strict_types.
        $firstClassifier = reset($this->classifiers);
        if ($firstClassifier !== false && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     *
     * The current object is cloned so that training the copies does not mess up
     * this object's own variables.
     */
    protected function getClassifierCopy(): Classifier
    {
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Predicts the label of a single sample.
     *
     * With two labels the single binary classifier decides. Otherwise the label
     * whose classifier reports the highest probability (compared numerically) wins.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];
        foreach ($this->classifiers as $label => $predictor) {
            // NOTE(review): PHP stores integer-like labels as int array keys, so $label
            // may be an int here although the abstract signature declares string.
            // Implementations must accept that (e.g. by leaving $label untyped).
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets).
     *
     * Receives the samples, the (possibly binarized) targets, and the two labels
     * the binary classifier has to distinguish.
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     *
     * Called from reset() to clear the implementing classifier's own state.
     */
    abstract protected function resetBinary(): void;

    /**
     * Each classifier that makes use of the OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample().
     *
     * Used directly when exactly two labels were seen during training.
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);

    /**
     * Groups all targets into two groups: targets equal to the given label and the others.
     *
     * $targets is received by value, so the caller's array is not modified.
     * The comparison is loose (==) on purpose. array_count_values() turns
     * numeric-string targets into int keys, and those int labels must still match
     * the original string targets.
     *
     * @param mixed $label
     *
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets(array $targets, $label): array
    {
        $notLabel = "not_{$label}";
        foreach ($targets as $key => $target) {
            $targets[$key] = $target == $label ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }
}
170
|
|
|
|