Passed
Push — master ( c0463a...e1854d )
by Arkadiusz
02:54
created

GD   A

Complexity

Total Complexity 10

Size/Duplication

Total Lines 111
Duplicated Lines 0 %

Coupling/Cohesion

Components 1
Dependencies 1

Importance

Changes 0
Metric Value
wmc 10
lcom 1
cbo 1
dl 0
loc 111
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
B runOptimization() 0 29 3
A gradient() 0 21 2
A updateWeightsWithUpdates() 0 19 4
A clear() 0 5 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
/**
 * Batch version of Gradient Descent to optimize the weights
 * of a classifier given samples, targets and the objective function to minimize.
 *
 * Each iteration evaluates the gradient callback for every sample against the
 * same snapshot of theta, then applies one combined update to all weights.
 */
class GD extends StochasticGD
{
    /**
     * Number of samples given to the current optimization run; null outside a run.
     *
     * @var int|null
     */
    protected $sampleCount = null;

    /**
     * Runs batch gradient descent for at most maxIterations iterations, stopping
     * earlier when earlyStop() returns true, and returns the resulting weights.
     *
     * @param array    $samples    feature rows; feature i-1 of a row feeds weight theta[i]
     * @param array    $targets    target values, keyed like $samples
     * @param \Closure $gradientCb called as ($theta, $sample, $target); its result is read
     *                             as [cost, gradient, penalty], missing trailing entries = 0
     *
     * @return array the optimized weights; index 0 is the bias (no feature attached)
     *
     * @throws \InvalidArgumentException when $samples is empty
     */
    public function runOptimization(array $samples, array $targets, \Closure $gradientCb)
    {
        if (count($samples) === 0) {
            // Costs and penalties are averaged over the sample count, so an empty
            // set would otherwise end in a division by zero.
            throw new \InvalidArgumentException('Batch gradient descent requires at least one sample');
        }

        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;
        $this->sampleCount = count($samples);
        $this->costValues = [];

        try {
            for ($iteration = 0; $iteration < $this->maxIterations; $iteration++) {
                // Snapshot theta so every sample in this batch sees the same weights.
                $theta = $this->theta;

                list($errors, $updates, $totalPenalty) = $this->gradient($theta);

                $this->updateWeightsWithUpdates($updates, $totalPenalty);

                $this->costValues[] = array_sum($errors) / $this->sampleCount;

                if ($this->earlyStop($theta)) {
                    break;
                }
            }
        } finally {
            // Reset the internal state via clear() even if the callback throws.
            $this->clear();
        }

        return $this->theta;
    }

    /**
     * Evaluates the gradient callback for every sample with the given theta.
     *
     * @param array $theta weights passed unchanged to the callback
     *
     * @return array [costs, updates, penalty]: per-sample costs and per-sample
     *               gradient terms as lists in sample order, plus the summed
     *               penalty divided by sampleCount
     */
    protected function gradient(array $theta)
    {
        $costs = [];
        $gradient = [];
        $totalPenalty = 0;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);
            // Callbacks may omit trailing entries; missing ones count as 0.
            list($cost, $grad, $penalty) = array_pad($result, 3, 0);

            $costs[] = $cost;
            $gradient[] = $grad;
            $totalPenalty += $penalty;
        }

        // NOTE(review): relies on sampleCount having been set by the caller
        // (runOptimization sets it before the first call).
        $totalPenalty /= $this->sampleCount;

        return [$costs, $gradient, $totalPenalty];
    }

    /**
     * Applies one batch step to every weight.
     *
     * theta[0] moves by the sum of all per-sample updates; each theta[i] (i >= 1)
     * moves by the updates weighted by feature i-1 of the matching sample, plus
     * the penalty scaled by the current value of theta[i].
     *
     * @param array $updates per-sample gradient terms, a list in sample order (as built by gradient())
     * @param float $penalty penalty factor applied to every non-bias weight
     */
    protected function updateWeightsWithUpdates(array $updates, float $penalty)
    {
        // The bias weight has no feature attached, so it takes the plain sum.
        $this->theta[0] -= $this->learningRate * array_sum($updates);

        // Pair rows with $updates by position. gradient() builds $updates as a list,
        // so re-index the samples the same way. array_column() would silently drop
        // rows lacking a column and shift every later row onto the wrong update.
        $rows = array_values($this->samples);

        for ($i = 1; $i <= $this->dimensions; $i++) {
            $error = 0;
            foreach ($rows as $position => $row) {
                $error += $row[$i - 1] * $updates[$position];
            }

            $this->theta[$i] -= $this->learningRate *
                ($error + $penalty * $this->theta[$i]);
        }
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     *
     * @return void
     */
    protected function clear()
    {
        $this->sampleCount = null;
        parent::clear();
    }
}
122