GD   A
last analyzed

Complexity

Total Complexity 11

Size/Duplication

Total Lines 96
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 11
eloc 43
dl 0
loc 96
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
A clear() 0 4 1
A gradient() 0 24 3
A runOptimization() 0 28 3
A updateWeightsWithUpdates() 0 16 4
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
use Closure;
8
use Phpml\Exception\InvalidOperationException;
9
10
/**
 * Batch version of Gradient Descent to optimize the weights
 * of a classifier given samples, targets and the objective function to minimize.
 *
 * Every iteration evaluates the gradient callback on all samples first and
 * only then adjusts all weights in a single step.
 */
class GD extends StochasticGD
{
    /**
     * Number of samples given
     *
     * Set at the start of runOptimization() and reset to null by clear().
     *
     * @var int|null
     */
    protected $sampleCount;

    /**
     * Runs the batch optimization and returns the resulting weights.
     *
     * The callback is invoked as $gradientCb($theta, $sample, $target) and is
     * expected to return an array of [cost, gradient, penalty]; missing trailing
     * entries are treated as 0 (see gradient()).
     *
     * @param array   $samples    Training samples; each sample is indexed by feature position
     * @param array   $targets    Target values, looked up with the same keys as $samples
     * @param Closure $gradientCb Per-sample cost/gradient callback
     *
     * @return array The weight vector ($this->theta) after the last iteration
     *
     * @throws InvalidOperationException If an iteration is attempted on an empty sample set
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): array
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;
        $this->sampleCount = count($this->samples);

        // Batch learning is executed:
        $currIter = 0;
        $this->costValues = [];

        try {
            while ($this->maxIterations > $currIter++) {
                // Snapshot of the weights before this iteration's update
                $theta = $this->theta;

                // Calculate update terms for each sample
                [$errors, $updates, $totalPenalty] = $this->gradient($theta);

                $this->updateWeightsWithUpdates($updates, $totalPenalty);

                // Mean cost over the batch. gradient() has already rejected an
                // empty sample set, so the divisor is non-zero here.
                $this->costValues[] = array_sum($errors) / (int) $this->sampleCount;

                // earlyStop() receives the pre-update weights; its stopping
                // criterion is defined outside this class.
                if ($this->earlyStop($theta)) {
                    break;
                }
            }
        } finally {
            // Release the references to the training data and the callback even
            // when the callback or gradient() throws.
            $this->clear();
        }

        return $this->theta;
    }

    /**
     * Calculates gradient, cost function and penalty term for each sample
     * then returns them as an array of values
     *
     * @param array $theta Weights handed to the gradient callback
     *
     * @return array [costs per sample, gradients per sample, penalty averaged over the samples]
     *
     * @throws InvalidOperationException If no gradient callback is set or there are no samples
     */
    protected function gradient(array $theta): array
    {
        $costs = [];
        $gradient = [];
        $totalPenalty = 0;

        if ($this->gradientCb === null) {
            throw new InvalidOperationException('Gradient callback is not defined');
        }

        // Fall back to counting the samples when the cached count is not set, and
        // refuse to continue with nothing to average over instead of dividing by zero.
        $sampleCount = $this->sampleCount ?? count($this->samples);
        if ($sampleCount < 1) {
            throw new InvalidOperationException('Cannot calculate the gradient of an empty sample set');
        }

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);
            // The callback may omit trailing values; those default to 0
            [$cost, $grad, $penalty] = array_pad($result, 3, 0);

            $costs[] = $cost;
            $gradient[] = $grad;
            $totalPenalty += $penalty;
        }

        $totalPenalty /= $sampleCount;

        return [$costs, $gradient, $totalPenalty];
    }

    /**
     * Applies one batch update to every weight in $this->theta.
     *
     * @param array $updates Per-sample gradient values, in sample iteration order
     * @param float $penalty Averaged penalty factor applied to the feature weights
     */
    protected function updateWeightsWithUpdates(array $updates, float $penalty): void
    {
        // Updates all weights at once
        for ($i = 0; $i <= $this->dimensions; ++$i) {
            if ($i === 0) {
                // theta[0] has no feature column: it moves by the plain sum of
                // the updates and receives no penalty.
                $this->theta[0] -= $this->learningRate * array_sum($updates);
            } else {
                // theta[$i] belongs to feature column $i - 1
                $col = array_column($this->samples, $i - 1);

                // Sum of the updates, each weighted by the sample's feature value
                $error = 0;
                foreach ($col as $index => $val) {
                    $error += $val * $updates[$index];
                }

                $this->theta[$i] -= $this->learningRate *
                    ($error + $penalty * $this->theta[$i]);
            }
        }
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     */
    protected function clear(): void
    {
        $this->sampleCount = null;
        parent::clear();
    }
}
112