Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Helper/Optimizer/GD.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
use Closure;
8
9
/**
 * Batch version of Gradient Descent to optimize the weights
 * of a classifier given samples, targets and the objective function to minimize.
 *
 * Each iteration evaluates the gradient callback for every sample and then
 * applies a single combined weight update (see updateWeightsWithUpdates()).
 */
class GD extends StochasticGD
{
    /**
     * Number of samples given to the current runOptimization() call;
     * reset to null by clear() when optimization finishes.
     *
     * @var int|null
     */
    protected $sampleCount;

    /**
     * Runs batch gradient descent and returns the optimized weights.
     *
     * @param array   $samples    training samples; feature values are read from each
     *                            sample by column index (0 .. dimensions - 1)
     * @param array   $targets    target for each sample, looked up by the sample's key
     * @param Closure $gradientCb called as ($theta, $sample, $target); its result is padded
     *                            to three entries [cost, gradient, penalty] with 0
     *
     * @return array the weights after the last iteration; index 0 holds the intercept
     *               (the weight not multiplied by any feature column)
     *
     * @throws \InvalidArgumentException when $samples is empty
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): array
    {
        // Costs and penalties are averaged over the sample count below; fail with a
        // clear message instead of a division by zero.
        if (count($samples) === 0) {
            throw new \InvalidArgumentException('Cannot run batch gradient descent without samples');
        }

        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;
        $this->sampleCount = count($samples);
        $this->costValues = [];

        try {
            // Batch learning: one weight update per iteration over all samples.
            for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
                // Keep the pre-update weights; earlyStop() receives them after the update.
                $theta = $this->theta;

                // Calculate update terms for each sample
                [$errors, $updates, $totalPenalty] = $this->gradient($theta);

                $this->updateWeightsWithUpdates($updates, $totalPenalty);

                // Record the mean cost of this iteration.
                $this->costValues[] = array_sum($errors) / $this->sampleCount;

                if ($this->earlyStop($theta)) {
                    break;
                }
            }
        } finally {
            // Release per-run state even when the gradient callback throws.
            $this->clear();
        }

        return $this->theta;
    }

    /**
     * Calculates gradient, cost function and penalty term for each sample
     * then returns them as an array of values.
     *
     * @param array $theta weights passed to the gradient callback
     *
     * @return array [costs, gradients, penalty]: per-sample costs and gradient values
     *               as 0-based lists in sample order, and the penalty terms averaged
     *               over the processed samples (0 when there are none)
     */
    protected function gradient(array $theta): array
    {
        $costs = [];
        $gradient = [];
        $totalPenalty = 0;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);
            // The callback may omit trailing entries (e.g. no penalty); pad them with 0.
            // $penalty is bound by this destructuring, so static-analysis warnings
            // claiming it is undefined are false positives.
            [$cost, $grad, $penalty] = array_pad($result, 3, 0);

            $costs[] = $cost;
            $gradient[] = $grad;
            $totalPenalty += $penalty;
        }

        // Average over the samples actually processed; avoid dividing by zero
        // when there are none.
        $processed = count($costs);
        if ($processed > 0) {
            $totalPenalty /= $processed;
        }

        return [$costs, $gradient, $totalPenalty];
    }

    /**
     * Applies one batch update to all weights at once.
     *
     * theta[0] is moved by the summed updates. Every other theta[i] is moved by the
     * updates weighted with feature column i - 1 of the samples, plus the penalty
     * times the current theta[i] (theta[0] receives no penalty).
     *
     * @param array $updates per-sample update values as a 0-based list in sample order
     * @param float $penalty averaged penalty factor produced by gradient()
     */
    protected function updateWeightsWithUpdates(array $updates, float $penalty): void
    {
        for ($i = 0; $i <= $this->dimensions; ++$i) {
            if ($i === 0) {
                $this->theta[0] -= $this->learningRate * array_sum($updates);

                continue;
            }

            // array_column() yields a 0-based list, aligned with $updates.
            // NOTE(review): samples lacking column $i - 1 are skipped by array_column(),
            // which would misalign the remaining values with $updates.
            $col = array_column($this->samples, $i - 1);

            $error = 0;
            foreach ($col as $index => $val) {
                $error += $val * $updates[$index];
            }

            $this->theta[$i] -= $this->learningRate *
                ($error + $penalty * $this->theta[$i]);
        }
    }

    /**
     * Clears the optimizer internal vars after the optimization process:
     * resets the sample count, then delegates to the parent's clear().
     */
    protected function clear(): void
    {
        $this->sampleCount = null;
        parent::clear();
    }
}
107