Passed
Push — master ( e83f7b...d953ef )
by Arkadiusz
03:28
created

src/Phpml/Helper/Optimizer/StochasticGD.php (3 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
use Closure;
8
9
/**
 * Stochastic Gradient Descent optimization method
 * to find a solution for the equation A.ϴ = y where
 *  A (samples) and y (targets) are known and ϴ is unknown.
 *
 * theta[0] holds the bias term and theta[1..dimensions] hold the
 * weights applied to the sample features.
 */
class StochasticGD extends Optimizer
{
    /**
     * A (samples)
     *
     * @var array
     */
    protected $samples = [];

    /**
     * y (targets)
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Callback function to get the gradient and cost value
     * for a specific set of theta (ϴ) and a pair of sample & target
     *
     * @var \Closure|null
     */
    protected $gradientCb = null;

    /**
     * Maximum number of iterations used to train the model
     *
     * @var int
     */
    protected $maxIterations = 1000;

    /**
     * Learning rate is used to control the speed of the optimization.<br>
     *
     * Larger values of lr may overshoot the optimum or even cause divergence
     * while small values slow down the convergence and increase the time
     * required for the training
     *
     * @var float
     */
    protected $learningRate = 0.001;

    /**
     * Minimum amount of change in the weights and error values
     * between iterations that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-4;

    /**
     * Enable/Disable early stopping by checking the weight & cost values
     * to see whether they changed large enough to continue the optimization
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * List of values obtained by evaluating the cost function at each iteration
     * of the algorithm
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initializes the SGD optimizer for the given number of dimensions
     *
     * @param int $dimensions number of feature weights (excluding the bias)
     */
    public function __construct(int $dimensions)
    {
        // Add one more dimension for the bias, stored in theta[0]
        parent::__construct($dimensions + 1);

        // Keep only the feature count here: updateTheta() walks
        // theta[1..dimensions] and handles the bias separately
        $this->dimensions = $dimensions;
    }

    /**
     * Sets minimum value for the change in the theta values
     * between iterations to continue the iterations.<br>
     *
     * If change in the theta is less than given value then the
     * algorithm will stop training
     *
     * @return $this
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;

        return $this;
    }

    /**
     * Enable/Disable early stopping by checking at each iteration
     * whether changes in theta or cost value are not large enough
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Sets the step size applied to each theta update.
     *
     * @return $this
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;

        return $this;
    }

    /**
     * Sets the upper bound on the number of passes over the samples.
     *
     * @return $this
     */
    public function setMaxIterations(int $maxIterations)
    {
        $this->maxIterations = $maxIterations;

        return $this;
    }

    /**
     * Optimization procedure finds the unknown variables for the equation A.ϴ = y
     * for the given samples (A) and targets (y).<br>
     *
     * The cost function to minimize and the gradient of the function are to be
     * handled by the callback function provided as the third parameter of the method.
     *
     * The callback is invoked as $gradientCb(array $theta, $sample, $target) and must
     * return [error, gradient] or [error, gradient, penalty].
     *
     * @return array|null the theta with the lowest observed cost; if no iteration
     *                    runs (maxIterations <= 0) the current theta is returned unchanged
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): ?array
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;

        $currIter = 0;
        $bestTheta = null;
        $bestScore = 0.0;
        $this->costValues = [];

        while ($this->maxIterations > $currIter++) {
            // Snapshot theta before the update: updateTheta() evaluates the
            // callback with this same (pre-update) theta, so $cost belongs to it
            $theta = $this->theta;

            // Update the guess
            $cost = $this->updateTheta();

            // Save the best theta in the "pocket" so that
            // any future set of theta worse than this will be disregarded
            if ($bestTheta === null || $cost <= $bestScore) {
                $bestTheta = $theta;
                $bestScore = $cost;
            }

            // Add the cost value for this iteration to the list
            $this->costValues[] = $cost;

            // Check for early stop
            if ($this->enableEarlyStop && $this->earlyStop($theta)) {
                break;
            }
        }

        $this->clear();

        // Solution in the pocket is better than or equal to the last state,
        // so we use it. When no iteration ran there is no pocket solution;
        // keep the current theta rather than overwriting it with an empty array.
        return $this->theta = $bestTheta ?? $this->theta;
    }

    /**
     * Returns the list of cost values for each iteration executed in
     * last run of the optimization
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Performs one pass over all samples, updating $this->theta per sample.
     *
     * @return float mean error reported by the callback over the pass
     *               (0.0 when there are no samples)
     */
    protected function updateTheta(): float
    {
        $sampleCount = count($this->samples);
        if ($sampleCount === 0) {
            // Nothing to learn from; also avoids dividing by zero below
            return 0.0;
        }

        $jValue = 0.0;
        // The callback always sees theta as it was at the start of this pass,
        // while the updates below are applied to $this->theta sample by sample
        $theta = $this->theta;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);

            // Destructuring defines $error, $gradient and $penalty here; array_pad
            // fills any entries the callback omitted with 0 (e.g. no penalty term)
            [$error, $gradient, $penalty] = array_pad($result, 3, 0);

            // Update bias
            $this->theta[0] -= $this->learningRate * $gradient;

            // Update the feature weights; $penalty scales the current weight,
            // acting as a weight-decay term
            for ($i = 1; $i <= $this->dimensions; ++$i) {
                $this->theta[$i] -= $this->learningRate *
                    ($gradient * $sample[$i - 1] + $penalty * $this->theta[$i]);
            }

            // Sum error rate
            $jValue += $error;
        }

        return $jValue / $sampleCount;
    }

    /**
     * Checks if the optimization is not effective enough and can be stopped
     * in case large enough changes in the solution do not happen
     *
     * @param array $oldTheta theta as it was before the latest update
     */
    protected function earlyStop(array $oldTheta): bool
    {
        // Check for early stop: no theta component changed by more than
        // $this->threshold (property default 1e-4)
        $diff = array_map(
            function ($w1, $w2) {
                return abs($w1 - $w2) > $this->threshold ? 1 : 0;
            },
            $oldTheta,
            $this->theta
        );

        if (array_sum($diff) === 0) {
            return true;
        }

        // Check if the last two cost values are almost the same
        $costs = array_slice($this->costValues, -2);
        if (count($costs) === 2 && abs($costs[1] - $costs[0]) < $this->threshold) {
            return true;
        }

        return false;
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     */
    protected function clear(): void
    {
        $this->samples = [];
        $this->targets = [];
        $this->gradientCb = null;
    }
}
262