Passed
Push — develop ( 63d496...62ec4e )
by Arkadiusz
02:18
created

KNearestNeighbors::predictSample()   A

Complexity

Conditions 2
Paths 2

Size

Total Lines 15
Code Lines 8

Duplication

Lines 0
Ratio 0 %
Metric Value
dl 0
loc 15
rs 9.4285
cc 2
eloc 8
nc 2
nop 1
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\Classifier;
6
7
use Phpml\Metric\Distance;
8
9
/**
 * k-nearest-neighbours classifier: a sample is assigned the label that is most
 * common among the k training samples closest to it by Euclidean distance.
 */
class KNearestNeighbors implements Classifier
{
    /**
     * Number of nearest neighbours that vote on each prediction.
     *
     * @var int
     */
    private $k;

    /**
     * Training samples, stored exactly as passed to train().
     *
     * @var array
     */
    private $samples;

    /**
     * Training labels; $labels[$i] is the label of $samples[$i].
     *
     * @var array
     */
    private $labels;

    /**
     * @param int $k number of nearest neighbours to consider; must be at least 1
     *
     * @throws \InvalidArgumentException when $k is smaller than 1
     */
    public function __construct(int $k = 3)
    {
        if ($k < 1) {
            // With k < 1 no neighbour would vote and the result would be arbitrary.
            throw new \InvalidArgumentException(sprintf('k must be a positive integer, %d given', $k));
        }

        $this->k = $k;
        $this->samples = [];
        $this->labels = [];
    }

    /**
     * Stores the training set. Nothing is fitted here; all distance work is
     * done at prediction time.
     *
     * @param array $samples training feature vectors
     * @param array $labels  labels whose keys match the keys of $samples
     */
    public function train(array $samples, array $labels)
    {
        $this->samples = $samples;
        $this->labels = $labels;
    }

    /**
     * Predicts the label of a single sample or of a batch of samples.
     *
     * A single sample is a flat array of feature values; a batch is an array of
     * such samples, and the returned array keeps the batch's keys.
     *
     * @param array $samples a single sample or a batch of samples
     *
     * @return mixed the predicted label for a single sample, or an array of
     *               predicted labels for a batch (empty for an empty batch)
     */
    public function predict(array $samples)
    {
        if ($samples === []) {
            // An empty batch has nothing to classify.
            return [];
        }

        // Look at the first element instead of $samples[0] so that batches with
        // non-sequential keys are still recognised as batches.
        if (!is_array(reset($samples))) {
            return $this->predictSample($samples);
        }

        $predicted = [];
        foreach ($samples as $index => $sample) {
            $predicted[$index] = $this->predictSample($sample);
        }

        return $predicted;
    }

    /**
     * Returns the label that occurs most often among the k nearest training
     * samples. Ties go to the label that appears first in the training labels.
     *
     * @param array $sample feature values of the sample to classify
     *
     * @return mixed the winning label, or null when there are no training labels
     */
    private function predictSample(array $sample)
    {
        $distances = $this->kNeighborsDistances($sample);

        // Vote counters, seeded in training order so that iteration order equals
        // the order in which each label first appears in the training set.
        $votes = [];
        // Array keys coerce labels (e.g. '1' becomes 1, true becomes 1); keep the
        // original value so the prediction is returned with its original type.
        $labelByKey = [];
        foreach ($this->labels as $label) {
            if (!isset($votes[$label])) {
                $votes[$label] = 0;
                $labelByKey[$label] = $label;
            }
        }

        foreach (array_keys($distances) as $index) {
            ++$votes[$this->labels[$index]];
        }

        // A linear scan with a strict '>' keeps the earliest label on ties, so the
        // outcome does not depend on whether the PHP version's sort is stable.
        $bestKey = null;
        $bestVotes = -1;
        foreach ($votes as $key => $count) {
            if ($count > $bestVotes) {
                $bestKey = $key;
                $bestVotes = $count;
            }
        }

        return $bestKey === null ? null : $labelByKey[$bestKey];
    }

    /**
     * Computes the Euclidean distance from $sample to every training sample and
     * keeps the k smallest.
     *
     * @param array $sample feature values of the sample to classify
     *
     * @return array distances keyed by training-sample key, nearest first,
     *               with at most k entries
     *
     * @throws \Phpml\Exception\InvalidArgumentException
     */
    private function kNeighborsDistances(array $sample): array
    {
        $distances = [];

        foreach ($this->samples as $index => $neighbor) {
            $distances[$index] = Distance::euclidean($sample, $neighbor);
        }

        asort($distances);

        // preserve_keys = true keeps each distance linked to its entry in $this->labels.
        return array_slice($distances, 0, $this->k, true);
    }
}
106