Test Setup Failed
Push — master ( f6aa1a...8544cf )
by Arkadiusz
02:28
created

src/Helper/Optimizer/StochasticGD.php (3 issues)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
use Closure;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\Exception\InvalidOperationException;
10
11
/**
 * Stochastic Gradient Descent optimization method
 * to find a solution for the equation A.ϴ = y where
 * A (samples) and y (targets) are known and ϴ is unknown.
 *
 * Index 0 of theta holds the bias; indices 1..dimensions hold the weight
 * of the corresponding feature (sample column i - 1).
 */
class StochasticGD extends Optimizer
{
    /**
     * A (samples)
     *
     * Only populated while runOptimization() is executing.
     *
     * @var array
     */
    protected $samples = [];

    /**
     * y (targets)
     *
     * Only populated while runOptimization() is executing.
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Callback function to get the gradient and cost value
     * for a specific set of theta (ϴ) and a pair of sample & target
     *
     * @var \Closure|null
     */
    protected $gradientCb;

    /**
     * Maximum number of iterations used to train the model
     *
     * @var int
     */
    protected $maxIterations = 1000;

    /**
     * Learning rate is used to control the speed of the optimization.
     *
     * Larger values of lr may overshoot the optimum or even cause divergence
     * while small values slow down the convergence and increase the time
     * required for the training
     *
     * @var float
     */
    protected $learningRate = 0.001;

    /**
     * Minimum amount of change in the weights and error values
     * between iterations that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-4;

    /**
     * Enable/Disable early stopping by checking the weight & cost values
     * to see whether they changed large enough to continue the optimization
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * List of values obtained by evaluating the cost function at each iteration
     * of the algorithm
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initializes the SGD optimizer for the given number of dimensions
     *
     * @param int $dimensions Number of features per sample (bias excluded)
     */
    public function __construct(int $dimensions)
    {
        // Add one more dimension for the bias
        // (presumably the parent sizes theta from this value - confirm in Optimizer)
        parent::__construct($dimensions + 1);

        // Keep the feature count itself, so that the loops below can treat
        // theta[0] as the bias and theta[1..dimensions] as feature weights
        $this->dimensions = $dimensions;
    }

    /**
     * Replaces the current weights.
     *
     * @param array $theta Bias followed by one weight per dimension
     *
     * @throws InvalidArgumentException when the array does not hold exactly dimensions + 1 values
     */
    public function setTheta(array $theta): Optimizer
    {
        if (count($theta) !== $this->dimensions + 1) {
            throw new InvalidArgumentException(sprintf('Number of values in the weights array should be %s', $this->dimensions + 1));
        }

        $this->theta = $theta;

        return $this;
    }

    /**
     * Sets minimum value for the change in the theta values
     * between iterations to continue the iterations.
     *
     * If change in the theta is less than given value then the
     * algorithm will stop training
     *
     * @return $this
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;

        return $this;
    }

    /**
     * Enable/Disable early stopping by checking at each iteration
     * whether changes in theta or cost value are not large enough
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Sets the step size applied to every gradient update.
     *
     * @return $this
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;

        return $this;
    }

    /**
     * Sets the upper bound on the number of passes over the samples.
     *
     * @return $this
     */
    public function setMaxIterations(int $maxIterations)
    {
        $this->maxIterations = $maxIterations;

        return $this;
    }

    /**
     * Optimization procedure finds the unknown variables for the equation A.ϴ = y
     * for the given samples (A) and targets (y).
     *
     * The cost function to minimize and the gradient of the function are to be
     * handled by the callback function provided as the third parameter of the method.
     *
     * @param array   $samples    Training samples (A)
     * @param array   $targets    Targets (y), looked up with the same keys as $samples
     * @param Closure $gradientCb Called as ($theta, $sample, $target); its result is read
     *                            as [error, gradient, penalty], missing entries default to 0
     *
     * @return array The best theta found, which also becomes the optimizer's current theta
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): array
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;

        $currIter = 0;
        $bestTheta = null;
        $bestScore = 0.0;
        $this->costValues = [];

        try {
            while ($this->maxIterations > $currIter++) {
                // Weights before this pass; this is the vector the callback
                // evaluates when it produces the cost of the pass
                $theta = $this->theta;

                // Update the guess
                $cost = $this->updateTheta();

                // Save the best theta in the "pocket" so that
                // any future set of theta worse than this will be disregarded
                if ($bestTheta === null || $cost <= $bestScore) {
                    $bestTheta = $theta;
                    $bestScore = $cost;
                }

                // Add the cost value for this iteration to the list
                $this->costValues[] = $cost;

                // Check for early stop
                if ($this->enableEarlyStop && $this->earlyStop($theta)) {
                    break;
                }
            }
        } finally {
            // Release the training data and the callback even when the
            // callback throws, so a failed run does not keep them alive
            $this->clear();
        }

        // Solution in the pocket is better than or equal to the last state
        // so, we use this solution. When no iteration ran at all (for example
        // maxIterations <= 0) the pocket is empty: keep the current weights
        // instead of overwriting them with an empty array.
        return $this->theta = $bestTheta ?? $this->theta;
    }

    /**
     * Returns the list of cost values for each iteration executed in
     * last run of the optimization
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Performs one pass over all samples, updating theta after each sample.
     *
     * @return float Mean of the error values reported by the callback,
     *               or 0.0 when there are no samples
     *
     * @throws InvalidOperationException when no gradient callback is set
     */
    protected function updateTheta(): float
    {
        if ($this->gradientCb === null) {
            throw new InvalidOperationException('Gradient callback is not defined');
        }

        // Nothing to learn from: report a zero cost rather than dividing by zero
        $sampleCount = count($this->samples);
        if ($sampleCount === 0) {
            return 0.0;
        }

        $jValue = 0.0;

        // Snapshot taken before the pass: the callback always sees these
        // values, not the weights being updated inside the loop
        $theta = $this->theta;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);

            // array_pad guarantees three entries, so all three variables are
            // always assigned here (penalty defaults to 0 when not returned)
            [$error, $gradient, $penalty] = array_pad($result, 3, 0);

            // Update bias
            $this->theta[0] -= $this->learningRate * $gradient;

            // Update other values
            for ($i = 1; $i <= $this->dimensions; ++$i) {
                $this->theta[$i] -= $this->learningRate *
                    ($gradient * $sample[$i - 1] + $penalty * $this->theta[$i]);
            }

            // Sum error rate
            $jValue += $error;
        }

        return $jValue / $sampleCount;
    }

    /**
     * Checks if the optimization is not effective enough and can be stopped
     * in case large enough changes in the solution do not happen
     *
     * @param array $oldTheta Weights as they were before the latest update
     */
    protected function earlyStop(array $oldTheta): bool
    {
        // Check for early stop: no weight changed by more than the threshold
        $diff = array_map(
            function ($w1, $w2) {
                return abs($w1 - $w2) > $this->threshold ? 1 : 0;
            },
            $oldTheta,
            $this->theta
        );

        // $diff only holds the ints 0 and 1, so the sum is an int
        if (array_sum($diff) === 0) {
            return true;
        }

        // Check if the last two cost values are almost the same
        $costs = array_slice($this->costValues, -2);
        if (count($costs) === 2 && abs($costs[1] - $costs[0]) < $this->threshold) {
            return true;
        }

        return false;
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     */
    protected function clear(): void
    {
        $this->samples = [];
        $this->targets = [];
        $this->gradientCb = null;
    }
}
279