Completed
Pull Request — master (#37)
by
unknown
02:43
created

KNearestNeighbors   A

Complexity

Total Complexity 8

Size/Duplication

Total Lines 96
Duplicated Lines 0 %

Coupling/Cohesion

Components 1
Dependencies 4

Importance

Changes 0
Metric Value
wmc 8
lcom 1
cbo 4
dl 0
loc 96
rs 10
c 0
b 0
f 0

5 Methods

Rating   Name   Duplication   Size   Complexity  
A __construct() 0 11 2
A serialize() 0 9 1
A unserialize() 0 8 1
A predictSample() 0 15 2
A kNeighborsDistances() 0 12 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Math\Distance;
10
use Phpml\Math\Distance\Euclidean;
11
12
class KNearestNeighbors implements Classifier, \Serializable
{
    use Trainable, Predictable;

    /**
     * Number of nearest training samples that vote on each prediction (always >= 1).
     *
     * @var int
     */
    private $k;

    /**
     * Metric used to measure the distance between a query sample and each training sample.
     *
     * @var Distance
     */
    private $distanceMetric;

    /**
     * @param int           $k              number of neighbours that vote on a prediction; must be >= 1
     * @param Distance|null $distanceMetric (if null then Euclidean distance as default)
     *
     * @throws \InvalidArgumentException when $k is lower than 1
     */
    public function __construct(int $k = 3, Distance $distanceMetric = null)
    {
        // A non-positive k is meaningless. With 0, no neighbour votes and the result is an
        // arbitrary label. With a negative value, array_slice() in kNeighborsDistances() would
        // drop elements from the END, i.e. vote with almost every training sample.
        if ($k < 1) {
            throw new \InvalidArgumentException(sprintf('k must be greater than or equal to 1, %d given', $k));
        }

        if (null === $distanceMetric) {
            $distanceMetric = new Euclidean();
        }

        $this->k = $k;
        $this->samples = [];
        $this->targets = [];
        $this->distanceMetric = $distanceMetric;
    }

    /**
     * Serializable implementation, kept for PHP versions without __serialize() support.
     *
     * @return string The serialized object
     */
    public function serialize()
    {
        return serialize($this->__serialize());
    }

    /**
     * Serializable implementation, kept for PHP versions without __unserialize() support.
     * It is also used on newer PHP versions to restore payloads written in the legacy format.
     *
     * NOTE: unserialize() can instantiate arbitrary classes (the distance metric is an object),
     * so never feed this method data from an untrusted source.
     *
     * @param string $data The serialized object
     */
    public function unserialize($data)
    {
        $this->__unserialize(unserialize($data));
    }

    /**
     * Native serialization hook (PHP >= 7.4). It is implemented in addition to \Serializable,
     * because implementing \Serializable alone is deprecated as of PHP 8.1.
     *
     * @return array the model state: k, distance metric, training samples and targets
     */
    public function __serialize(): array
    {
        return [
            'k' => $this->k,
            'distanceMetric' => $this->distanceMetric,
            'samples' => $this->samples,
            'targets' => $this->targets,
        ];
    }

    /**
     * Native unserialization hook (PHP >= 7.4): restores the state produced by __serialize().
     *
     * @param array $data
     */
    public function __unserialize(array $data)
    {
        $this->k = $data['k'];
        $this->distanceMetric = $data['distanceMetric'];
        $this->samples = $data['samples'];
        $this->targets = $data['targets'];
    }

    /**
     * Majority vote among the k training samples closest to $sample.
     *
     * @param array $sample
     *
     * @return mixed the label with the most votes (as an array key, so numeric labels come back as int)
     */
    protected function predictSample(array $sample)
    {
        $distances = $this->kNeighborsDistances($sample);

        // Start every known label at zero votes, ordered by first appearance in the training targets.
        $predictions = array_combine(array_values($this->targets), array_fill(0, count($this->targets), 0));

        // $distances is keyed by training-sample index, which also indexes $this->targets.
        foreach (array_keys($distances) as $index) {
            ++$predictions[$this->targets[$index]];
        }

        arsort($predictions);
        reset($predictions);

        return key($predictions);
    }

    /**
     * Distances from $sample to its k nearest training samples.
     *
     * @param array $sample
     *
     * @return array distance values keyed by training-sample index, ascending by distance
     *
     * @throws \Phpml\Exception\InvalidArgumentException
     */
    private function kNeighborsDistances(array $sample)
    {
        $distances = [];

        foreach ($this->samples as $index => $neighbor) {
            $distances[$index] = $this->distanceMetric->distance($sample, $neighbor);
        }

        asort($distances);

        // preserve_keys = true so callers can map each distance back to its target.
        return array_slice($distances, 0, $this->k, true);
    }
}
108