Issues (14)

src/Helper/OneVsRest.php (1 issue)

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
use Phpml\Classification\Classifier;
8
9
trait OneVsRest
{
    /**
     * Trained binary classifiers.
     *
     * Holds a single entry under key 0 when exactly two labels are known;
     * otherwise holds one entry per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels, sorted as strings.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values reported by the first trained classifier, when that
     * classifier exposes a getCostValues() method.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Trains the classifier(s) in the one-vs-rest style.
     *
     * Any state left over from a previous training run is discarded first.
     */
    public function train(array $samples, array $targets): void
    {
        // Start from a clean slate so earlier runs cannot leak into this one.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Resets the classifier and the variables used internally by OneVsRest
     * to manage multiple classifiers.
     */
    public function reset(): void
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Trains one binary classifier per label, or a single classifier when
     * exactly two labels are known.
     *
     * @param array $allLabels Complete label set; derived from $targets when empty.
     *                         Must be provided on each partial training run.
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
    {
        // Overwrites the current value if it exists.
        $this->allLabels = count($allLabels) === 0 ? array_keys(array_count_values($targets)) : $allLabels;
        sort($this->allLabels, SORT_STRING);

        // If there are only two labels, then there is no need to perform OvR.
        if (count($this->allLabels) === 2) {
            // Init classifier if required.
            if (count($this->classifiers) === 0) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate classifier for each label and memorize them.
            foreach ($this->allLabels as $label) {
                // Init classifier if required.
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier is capable of giving the cost values
        // during the training, then assign them to the relevant variable.
        // Only the first classifier's cost values are kept to avoid complex
        // average calculations.
        // Note: this is PHP's array reset(), not $this->reset(). It returns
        // false when no classifier exists (e.g. empty targets), and passing
        // false to method_exists() raises a TypeError on PHP 8.
        $firstClassifier = reset($this->classifiers);
        if ($firstClassifier !== false && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Returns a clone of the current object after cleaning up OneVsRest state.
     *
     * The clone is returned under the Classifier return type, so the class
     * using this trait must be a Classifier. Static analysis reports the
     * return as incompatible because the trait itself is not a Classifier.
     */
    protected function getClassifierCopy(): Classifier
    {
        // Clone the current classifier, so that we don't mess up its
        // variables while training multiple instances of this classifier.
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Predicts the label of a single sample.
     *
     * With exactly two labels the single binary classifier decides; otherwise
     * every per-label classifier is queried and the label with the highest
     * reported probability is returned.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        // Sort descending by probability; the first key is then the best label.
        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets).
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     */
    abstract protected function resetBinary(): void;

    /**
     * Each classifier that makes use of the OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample().
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);

    /**
     * Groups all targets into two groups: targets equal to the given label
     * and the others.
     *
     * $targets is not passed by reference nor contains objects, so the
     * changes made here will not affect the caller's $targets array.
     *
     * Targets are matched against the label by their string representation.
     * Labels derived through array keys may have been converted from numeric
     * strings to integers, so a strict comparison would miss them, while a
     * loose comparison would merge distinct labels such as '1.0' and 1.
     *
     * @param mixed $label
     *
     * @return array Binarized targets and target's labels
     */
    private function binarizeTargets(array $targets, $label): array
    {
        $notLabel = "not_{$label}";
        $labelAsString = (string) $label;

        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target === $labelAsString ? $label : $notLabel;
        }

        $labels = [$label, $notLabel];

        return [$targets, $labels];
    }
}
170