KNearestNeighbors::predictSample()   A
last analyzed

Complexity

Conditions 2
Paths 2

Size

Total Lines 13
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 13
rs 10
c 0
b 0
f 0
cc 2
nc 2
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Distance;
10
use Phpml\Math\Distance\Euclidean;
11
12
class KNearestNeighbors implements Classifier
{
    use Trainable;
    use Predictable;

    /**
     * Number of nearest training samples that vote on each prediction (validated to be >= 1).
     *
     * @var int
     */
    private $k;

    /**
     * Metric used to measure how far a query sample is from each training sample.
     *
     * @var Distance
     */
    private $distanceMetric;

    /**
     * @param int           $k              number of neighbours that vote on a prediction; must be at least 1
     * @param Distance|null $distanceMetric (if null then Euclidean distance as default)
     *
     * @throws \InvalidArgumentException when $k is smaller than 1
     */
    public function __construct(int $k = 3, ?Distance $distanceMetric = null)
    {
        // kNeighborsDistances() passes $k to array_slice() as a length: a negative length
        // means "all but the last |k| elements" and zero means "nothing", so an unchecked
        // k silently produced a meaningless neighbour set.
        if ($k < 1) {
            throw new \InvalidArgumentException(sprintf('k must be at least 1, %d given', $k));
        }

        if ($distanceMetric === null) {
            $distanceMetric = new Euclidean();
        }

        $this->k = $k;
        $this->samples = [];
        $this->targets = [];
        $this->distanceMetric = $distanceMetric;
    }

    /**
     * Predicts the label of one sample by majority vote among its k nearest training samples.
     *
     * When several labels receive the same number of votes, the label of the closest
     * neighbour wins. The tie-break is done explicitly here instead of relying on sort
     * stability, which PHP only guarantees from 8.0 onwards.
     *
     * @return mixed the winning label exactly as stored in the training targets,
     *               or null when the model holds no training samples
     *
     * @throws \Phpml\Exception\InvalidArgumentException presumably raised by the distance metric (see kNeighborsDistances)
     */
    protected function predictSample(array $sample)
    {
        $votes = [];
        $labels = [];

        // Neighbours arrive nearest-first, so the first vote a label receives also fixes
        // its position in $votes; the tie-break loop below relies on that order.
        foreach (array_keys($this->kNeighborsDistances($sample)) as $index) {
            $label = $this->targets[$index];

            if (!isset($votes[$label])) {
                $votes[$label] = 0;
                // Array keys coerce numeric strings to int ('1' becomes 1); keep the
                // original value so callers get back the label they trained with.
                $labels[$label] = $label;
            }

            ++$votes[$label];
        }

        $winner = null;
        $bestCount = 0;
        foreach ($votes as $key => $count) {
            // Strictly greater: on a tie the earlier (closer) label is kept.
            if ($count > $bestCount) {
                $bestCount = $count;
                $winner = $labels[$key];
            }
        }

        return $winner;
    }

    /**
     * Computes the distance from $sample to every training sample and keeps the k smallest.
     *
     * @return array distances keyed by the sample's index in $this->samples, ordered nearest first
     *
     * @throws \Phpml\Exception\InvalidArgumentException presumably from the distance metric
     *                                                   (e.g. on mismatched dimensions) — not thrown directly here
     */
    private function kNeighborsDistances(array $sample): array
    {
        $distances = [];

        foreach ($this->samples as $index => $neighbor) {
            $distances[$index] = $this->distanceMetric->distance($sample, $neighbor);
        }

        // asort keeps the sample indices as keys; preserve_keys=true below keeps them too,
        // so callers can map each neighbour back to $this->targets.
        asort($distances);

        return array_slice($distances, 0, $this->k, true);
    }
}
76