Passed
Push — master ( c44f3b...492344 )
by Arkadiusz
02:48
created

GD::runOptimization()   B

Complexity

Conditions 3
Paths 3

Size

Total Lines 27
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 27
rs 8.8571
c 0
b 0
f 0
cc 3
eloc 15
nc 3
nop 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
/**
 * Batch version of Gradient Descent to optimize the weights
 * of a classifier given samples, targets and the objective function to minimize.
 *
 * Every iteration evaluates the gradient callback on all samples first and
 * then applies one combined update to all weights at once.
 */
class GD extends StochasticGD
{
    /**
     * Number of samples given to the most recent runOptimization() call.
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Runs batch gradient descent over the given training data.
     *
     * Each iteration evaluates $gradientCb for every sample, applies a single
     * combined weight update, and appends the mean per-sample cost to
     * $this->costValues. The loop runs at most maxIterations times and stops
     * earlier as soon as earlyStop() returns true.
     *
     * @param array    $samples    training samples; feature j of a sample is read at index j
     * @param array    $targets    target values, keyed the same way as $samples
     * @param \Closure $gradientCb called as ($theta, $sample, $target); returns
     *                             [cost, gradient term, penalty], where missing
     *                             trailing entries default to 0
     *
     * @return array the optimized weights; index 0 is the bias term
     *
     * @throws \InvalidArgumentException if $samples is empty
     */
    public function runOptimization(array $samples, array $targets, \Closure $gradientCb)
    {
        // Costs and penalties are averaged over the sample count below, so an
        // empty training set would divide by zero (DivisionByZeroError on PHP 8,
        // NAN plus a warning on PHP 7). Reject it before touching any state.
        if ($samples === []) {
            throw new \InvalidArgumentException('Cannot run gradient descent without samples');
        }

        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;
        $this->sampleCount = count($this->samples);

        // Batch learning is executed:
        $currIter = 0;
        $this->costValues = [];
        while ($this->maxIterations > $currIter++) {
            // Snapshot of the weights before this iteration's update; earlyStop()
            // compares it against the updated $this->theta.
            $theta = $this->theta;

            // Calculate update terms for each sample
            list($errors, $updates, $totalPenalty) = $this->gradient($theta);

            $this->updateWeightsWithUpdates($updates, $totalPenalty);

            $this->costValues[] = array_sum($errors) / $this->sampleCount;

            // NOTE(review): earlyStop() is consulted unconditionally here; if the
            // parent optimizer offers a switch to disable early stopping, confirm
            // it is honoured (here or inside earlyStop()).
            if ($this->earlyStop($theta)) {
                break;
            }
        }

        return $this->theta;
    }

    /**
     * Evaluates the gradient callback on every sample with the given weights.
     *
     * @param array $theta weights passed to the callback
     *
     * @return array [per-sample costs, per-sample gradient terms (re-indexed
     *               from 0 in sample order), penalty averaged over $this->sampleCount]
     */
    protected function gradient(array $theta)
    {
        $costs = [];
        $gradient = [];
        $totalPenalty = 0;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);
            // The callback may omit trailing entries; they default to 0.
            list($cost, $grad, $penalty) = array_pad($result, 3, 0);

            $costs[] = $cost;
            $gradient[] = $grad;
            $totalPenalty += $penalty;
        }

        $totalPenalty /= $this->sampleCount;

        return [$costs, $gradient, $totalPenalty];
    }

    /**
     * Applies one batch update to all weights at once.
     *
     * theta[0] moves by the plain sum of the per-sample gradient terms. Each
     * other theta[i] moves by the sum of those terms weighted by feature (i - 1)
     * of each sample, plus $penalty times the current theta[i].
     *
     * @param array $updates per-sample gradient terms, indexed from 0 in sample order
     * @param float $penalty averaged penalty factor applied to the non-bias weights
     */
    protected function updateWeightsWithUpdates(array $updates, float $penalty)
    {
        // The bias has no feature multiplier and receives no penalty term.
        $this->theta[0] -= $this->learningRate * array_sum($updates);

        for ($i = 1; $i <= $this->dimensions; $i++) {
            // array_column() re-indexes from 0 in sample order, which lines up
            // with the 0-based $updates list built by gradient().
            $col = array_column($this->samples, $i - 1);

            $error = 0;
            foreach ($col as $index => $val) {
                $error += $val * $updates[$index];
            }

            $this->theta[$i] -= $this->learningRate *
                ($error + $penalty * $this->theta[$i]);
        }
    }
}
109