Test Failed
Branch master (01bb82)
by Arkadiusz
05:58 queued 03:02
created

OneVsRest::predictSampleBinary()

Size

Total Lines 1

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 1
c 0
b 0
f 0
nc 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper;
6
7
/**
 * One-vs-Rest (OvR) training and prediction on top of a binary classifier.
 *
 * The using class implements trainBinary(), predictSampleBinary() and
 * predictProbability(). This trait supplies train(), which fits one clone of
 * the classifier per label (or a single clone when there are exactly two
 * labels), and predictSample(), which picks the label whose classifier
 * reports the highest probability.
 */
trait OneVsRest
{
    /**
     * Samples given to the most recent train() call that used the OvR branch.
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Targets given to the most recent train() call that used the OvR branch;
     * read by binarizeTargets().
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Trained classifier clones: a single entry at index 0 when there are
     * exactly two labels, otherwise one entry per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers;

    /**
     * Distinct target values seen by the last train() call, as produced by
     * array_keys(array_count_values($targets)); numeric strings therefore
     * appear here as integers.
     *
     * @var array
     */
    protected $labels;

    /**
     * Train a binary classifier in the OvR style
     *
     * @param array $samples
     * @param array $targets
     */
    public function train(array $samples, array $targets)
    {
        // Work on a clone so that training several instances of this
        // classifier does not disturb the variables of $this.
        $classifier = clone $this;

        // The clone is taken before $this->classifiers is cleared, so without
        // this reset every predictor would keep the classifiers of the previous
        // train() call alive (and, through them, all earlier generations).
        $classifier->classifiers = [];
        $this->classifiers = [];

        $this->labels = array_keys(array_count_values($targets));

        // If there are only two targets, then there is no need to perform OvR
        if (count($this->labels) === 2) {
            $classifier->trainBinary($samples, $targets);
            $this->classifiers[] = $classifier;

            return;
        }

        // Train a separate classifier for each label and memorize them
        $this->samples = $samples;
        $this->targets = $targets;

        foreach ($this->labels as $label) {
            $predictor = clone $classifier;
            $predictor->trainBinary($samples, $this->binarizeTargets($label));
            $this->classifiers[$label] = $predictor;
        }
    }

    /**
     * Groups all targets into two groups: Targets equal to
     * the given label and the others
     *
     * Equality is decided on the string form of both values. The label is an
     * array key, so a target such as '10' arrives here as the integer 10; a
     * loose == comparison would additionally treat distinct targets such as
     * '1e1' and '10' as the same class.
     *
     * @param mixed $label
     *
     * @return array one entry per element of $this->targets: $label itself
     *               for matching targets, the string "not_$label" otherwise
     */
    private function binarizeTargets($label)
    {
        $positive = (string) $label;
        $negative = "not_$label";

        $targets = [];

        foreach ($this->targets as $target) {
            // Non-scalar targets cannot equal an array key, so they always
            // belong to the "rest" group (and must not be cast to string).
            $matches = is_scalar($target) && (string) $target === $positive;
            $targets[] = $matches ? $label : $negative;
        }

        return $targets;
    }

    /**
     * Predicts the label of a single sample.
     *
     * With exactly two labels the single binary classifier decides directly.
     * Otherwise every per-label classifier is asked for a probability and the
     * label with the highest one is returned.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->labels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            // Numeric labels are integer array keys, while predictProbability()
            // declares a string parameter; without the cast the call raises a
            // TypeError in files that declare strict_types=1.
            $probs[$label] = $predictor->predictProbability($sample, (string) $label);
        }

        // Highest probability first; keys (the labels) are preserved.
        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     *
     * @param array $samples
     * @param array $targets
     */
    abstract protected function trainBinary(array $samples, array $targets);

    /**
     * Each classifier that makes use of the OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @param array  $sample
     * @param string $label
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @param array $sample
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);
}
127