Passed
Pull Request — master (#34)
by Benjamin
13:00
created

Algorithm::fit()   A

Complexity

Conditions 4
Paths 2

Size

Total Lines 23
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 9
nc 2
nop 3
dl 0
loc 23
rs 9.9666
c 0
b 0
f 0
1
<?php
2
3
namespace Kmeans;
4
5
use Kmeans\Cluster;
6
use Kmeans\ClusterCollection;
7
use Kmeans\Interfaces\AlgorithmInterface;
8
use Kmeans\Interfaces\ClusterCollectionInterface;
9
use Kmeans\Interfaces\ClusterInterface;
10
use Kmeans\Interfaces\InitializationSchemeInterface;
11
use Kmeans\Interfaces\PointCollectionInterface;
12
use Kmeans\Interfaces\PointInterface;
13
14
/**
 * Base class for iterative, centroid-based clustering algorithms.
 *
 * Clusters are seeded by the injected initialization scheme and then refined
 * by repeatedly moving every point to the cluster whose centroid is nearest,
 * until an iteration moves nothing or the iteration budget is spent.
 *
 * NOTE(review): this class calls $this->getDistanceBetween() and
 * $this->findCentroid() without defining them; they are presumably declared
 * by AlgorithmInterface and implemented by concrete subclasses - confirm.
 */
abstract class Algorithm implements AlgorithmInterface
{
    /**
     * Strategy used by fit() to build the initial clusters.
     */
    private InitializationSchemeInterface $initScheme;

    /**
     * Callbacks invoked before each iteration, in registration order.
     *
     * @var array<callable>
     */
    private array $iterationCallbacks = [];

    /**
     * @param InitializationSchemeInterface $initScheme strategy that creates the initial clusters
     */
    public function __construct(InitializationSchemeInterface $initScheme)
    {
        $this->initScheme = $initScheme;
    }

    /**
     * Register a callback to run before each iteration of fit().
     *
     * The callback is called with this algorithm instance as first argument
     * and the current cluster collection as second argument.
     */
    public function registerIterationCallback(callable $callback): void
    {
        $this->iterationCallbacks[] = $callback;
    }

    /**
     * Cluster the given points.
     *
     * @param PointCollectionInterface $points    points handed to the initialization scheme
     * @param int                      $nClusters number of clusters requested from the initialization scheme
     * @param int|null                 $maxIter   maximum number of iterations, or null for no limit
     *
     * @return ClusterCollectionInterface the collection returned by the initialization scheme, after refinement
     *
     * @throws \UnexpectedValueException if $maxIter is given and lower than 1
     */
    public function fit(
        PointCollectionInterface $points,
        int $nClusters,
        ?int $maxIter = null
    ): ClusterCollectionInterface {
        if ($maxIter !== null && $maxIter < 1) {
            throw new \UnexpectedValueException(
                "Invalid maximum number of iterations: {$maxIter}"
            );
        }

        // initialize clusters
        $clusters = $this->initScheme->initializeClusters($points, $nClusters);

        // null means "no limit"; the countdown stays an integer instead of
        // being replaced by the float INF
        $remaining = $maxIter;

        // iterate until convergence is reached or the budget is exhausted
        do {
            $this->invokeIterationCallbacks($clusters);

            if (!$this->iterate($clusters)) {
                // nothing moved during this iteration: converged
                break;
            }
        } while ($remaining === null || --$remaining > 0);

        // clustering is done.
        return $clusters;
    }

    /**
     * Run a single reassignment pass over all clusters.
     *
     * Every point is compared against the centroids as they stand at the
     * start of the pass; centroids of clusters that gained or lost points are
     * recomputed only once all moves have been applied.
     *
     * @return bool true if at least one point changed cluster
     *
     * @throws \RuntimeException if no closest cluster can be determined for a point
     */
    protected function iterate(ClusterCollectionInterface $clusters): bool
    {
        // Snapshot the current memberships into a plain array before touching
        // anything. Moving points while a foreach is still walking the
        // cluster's own point collection (or re-iterating $clusters inside an
        // outer loop over $clusters) is unsafe when those collections keep a
        // single internal cursor, e.g. when backed by SplObjectStorage.
        $memberships = [];

        foreach ($clusters as $cluster) {
            foreach ($cluster->getPoints() as $point) {
                $memberships[] = [$cluster, $point];
            }
        }

        /** @var \SplObjectStorage<ClusterInterface, null> */
        $changed = new \SplObjectStorage();

        foreach ($memberships as [$cluster, $point]) {
            // find the closest cluster
            $closest = $this->getClosestCluster($clusters, $point);

            if ($closest === $cluster) {
                continue;
            }

            // move the point from its current cluster to its closest
            $cluster->detach($point);
            $closest->attach($point);

            // flag both clusters as changed
            $changed->attach($cluster);
            $changed->attach($closest);
        }

        // update changed clusters' centroid
        foreach ($changed as $cluster) {
            $cluster->setCentroid($this->findCentroid($cluster->getPoints()));
        }

        // return true if something changed during this iteration
        return count($changed) > 0;
    }

    /**
     * Find the cluster whose centroid is nearest to the given point.
     *
     * On equal distances the cluster encountered first wins, because only a
     * strictly smaller distance replaces the current best.
     *
     * @throws \RuntimeException if the collection yields no cluster, or no
     *                           cluster has a distance strictly below INF
     */
    private function getClosestCluster(ClusterCollectionInterface $clusters, PointInterface $point): ClusterInterface
    {
        $min = INF;
        $closest = null;

        foreach ($clusters as $cluster) {
            $distance = $this->getDistanceBetween($point, $cluster->getCentroid());

            if ($distance < $min) {
                $min = $distance;
                $closest = $cluster;
            }
        }

        // An explicit exception instead of assert(): assertions can be
        // disabled at runtime, which would turn this case into an opaque
        // return-type TypeError.
        if ($closest === null) {
            throw new \RuntimeException(
                'Unable to determine the closest cluster for the given point'
            );
        }

        return $closest;
    }

    /**
     * Call every registered iteration callback with this algorithm and the
     * current clusters.
     */
    private function invokeIterationCallbacks(ClusterCollectionInterface $clusters): void
    {
        foreach ($this->iterationCallbacks as $callback) {
            $callback($this, $clusters);
        }
    }
}
115