StochasticGD   A
last analyzed

Complexity

Total Complexity 24

Size/Duplication

Total Lines 261
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 24
eloc 70
dl 0
loc 261
rs 10
c 0
b 0
f 0

11 Methods

Rating   Name   Duplication   Size   Complexity  
A __construct() 0 6 1
A clear() 0 5 1
A getCostValues() 0 3 1
A updateTheta() 0 30 4
A setTheta() 0 9 2
A setChangeThreshold() 0 5 1
A earlyStop() 0 22 5
A setLearningRate() 0 5 1
A runOptimization() 0 38 6
A setEarlyStop() 0 5 1
A setMaxIterations() 0 5 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
use Closure;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\Exception\InvalidOperationException;
10
11
/**
 * Stochastic Gradient Descent optimization method
 * to find a solution for the equation A.ϴ = y where
 *  A (samples) and y (targets) are known and ϴ is unknown.
 */
class StochasticGD extends Optimizer
{
    /**
     * A (samples)
     *
     * Only populated while runOptimization() is executing; emptied by clear().
     *
     * @var array
     */
    protected $samples = [];

    /**
     * y (targets)
     *
     * Looked up with the same keys as the samples array.
     * Only populated while runOptimization() is executing; emptied by clear().
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Callback function to get the gradient and cost value
     * for a specific set of theta (ϴ) and a pair of sample & target
     *
     * It is invoked as ($theta, $sample, $target) and its result is read as
     * [error, gradient, penalty]; missing trailing entries are treated as 0.
     *
     * @var \Closure|null
     */
    protected $gradientCb;

    /**
     * Maximum number of iterations used to train the model
     *
     * @var int
     */
    protected $maxIterations = 1000;

    /**
     * Learning rate is used to control the speed of the optimization.<br>
     *
     * Larger values of lr may overshoot the optimum or even cause divergence
     * while small values slow down the convergence and increase the time
     * required for the training
     *
     * @var float
     */
    protected $learningRate = 0.001;

    /**
     * Minimum amount of change in the weights and error values
     * between iterations that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-4;

    /**
     * Enable/Disable early stopping by checking the weight & cost values
     * to see whether they changed large enough to continue the optimization
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * List of values obtained by evaluating the cost function at each iteration
     * of the algorithm
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initializes the SGD optimizer for the given number of dimensions
     *
     * The parent constructor is called with one extra dimension (the bias),
     * after which the dimension count stored on this object is set back to
     * the number of features, so theta is expected to hold $dimensions + 1 values.
     */
    public function __construct(int $dimensions)
    {
        // Add one more dimension for the bias
        parent::__construct($dimensions + 1);

        $this->dimensions = $dimensions;
    }

    /**
     * Replaces the current weights (ϴ).
     *
     * @param array $theta Bias followed by one weight per dimension
     *
     * @throws InvalidArgumentException when the array does not contain
     *                                  exactly dimensions + 1 values
     */
    public function setTheta(array $theta): Optimizer
    {
        if (count($theta) !== $this->dimensions + 1) {
            throw new InvalidArgumentException(sprintf('Number of values in the weights array should be %s', $this->dimensions + 1));
        }

        $this->theta = $theta;

        return $this;
    }

    /**
     * Sets minimum value for the change in the theta values
     * between iterations to continue the iterations.<br>
     *
     * If change in the theta is less than given value then the
     * algorithm will stop training
     *
     * Note that calling this method without an argument sets the threshold
     * to 1e-5, which differs from the 1e-4 used when it is never called.
     *
     * @return $this
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;

        return $this;
    }

    /**
     * Enable/Disable early stopping by checking at each iteration
     * whether changes in theta or cost value are not large enough
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Sets the step size applied to every gradient update.
     *
     * @return $this
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;

        return $this;
    }

    /**
     * Sets the upper bound on the number of passes over the samples.
     *
     * @return $this
     */
    public function setMaxIterations(int $maxIterations)
    {
        $this->maxIterations = $maxIterations;

        return $this;
    }

    /**
     * Optimization procedure finds the unknown variables for the equation A.ϴ = y
     * for the given samples (A) and targets (y).<br>
     *
     * The cost function to minimize and the gradient of the function are to be
     * handled by the callback function provided as the third parameter of the method.
     *
     * @param array   $samples    Samples (A); each sample is indexed from 0 to dimensions - 1
     * @param array   $targets    Targets (y), keyed like $samples
     * @param Closure $gradientCb Called as ($theta, $sample, $target), returns [error, gradient, penalty]
     *
     * @return array The theta values with the lowest recorded cost; the current
     *               theta is returned unchanged when no iteration is executed
     *
     * @throws InvalidArgumentException when an iteration is attempted without any samples
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): array
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;

        $currIter = 0;
        $bestTheta = null;
        $bestScore = 0.0;
        $this->costValues = [];

        try {
            while ($this->maxIterations > $currIter++) {
                // Snapshot of the weights before this pass; the cost returned by
                // updateTheta() is evaluated against exactly these values
                $theta = $this->theta;

                // Update the guess
                $cost = $this->updateTheta();

                // Save the best theta in the "pocket" so that
                // any future set of theta worse than this will be disregarded
                if ($bestTheta === null || $cost <= $bestScore) {
                    $bestTheta = $theta;
                    $bestScore = $cost;
                }

                // Add the cost value for this iteration to the list
                $this->costValues[] = $cost;

                // Check for early stop
                if ($this->enableEarlyStop && $this->earlyStop($theta)) {
                    break;
                }
            }
        } finally {
            // Release the training data and the callback even when the
            // callback (or the update step) throws
            $this->clear();
        }

        // Solution in the pocket is better than or equal to the last state
        // so, we use this solution. When the loop never ran (maxIterations <= 0)
        // the pocket is empty and the current theta must be kept rather than
        // being replaced by an empty array.
        return $this->theta = $bestTheta ?? $this->theta;
    }

    /**
     * Returns the list of cost values for each iteration executed in
     * last run of the optimization
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Runs one pass over all samples, adjusting theta after every sample.
     *
     * The callback always receives the theta values as they were when this
     * method was entered, not the values already adjusted during the pass.
     *
     * @return float Mean of the error values reported by the callback
     *
     * @throws InvalidOperationException when no gradient callback is set
     * @throws InvalidArgumentException  when there are no samples to iterate over
     */
    protected function updateTheta(): float
    {
        if ($this->gradientCb === null) {
            throw new InvalidOperationException('Gradient callback is not defined');
        }

        $sampleCount = count($this->samples);
        if ($sampleCount === 0) {
            // Averaging over zero samples would be a division by zero
            throw new InvalidArgumentException('Cannot run the optimization with an empty sample set');
        }

        $jValue = 0.0;
        $theta = $this->theta;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);

            // The callback may omit the penalty (and gradient); default them to 0
            [$error, $gradient, $penalty] = array_pad($result, 3, 0);

            // Update bias
            $this->theta[0] -= $this->learningRate * $gradient;

            // Update other values; theta[$i] pairs with feature $i - 1
            // because index 0 is reserved for the bias
            for ($i = 1; $i <= $this->dimensions; ++$i) {
                $this->theta[$i] -= $this->learningRate *
                    ($gradient * $sample[$i - 1] + $penalty * $this->theta[$i]);
            }

            // Sum error rate
            $jValue += $error;
        }

        return $jValue / $sampleCount;
    }

    /**
     * Checks if the optimization is not effective enough and can be stopped
     * in case large enough changes in the solution do not happen
     *
     * @param array $oldTheta Theta values from before the latest update
     *
     * @return bool True when no weight moved by more than the threshold, or when
     *              the last two recorded cost values differ by less than the threshold
     */
    protected function earlyStop(array $oldTheta): bool
    {
        $threshold = $this->threshold;

        // Compare old and new weights position by position and flag
        // every weight that changed by more than the threshold
        $changed = array_map(
            static function ($w1, $w2) use ($threshold): bool {
                return abs($w1 - $w2) > $threshold;
            },
            $oldTheta,
            $this->theta
        );

        // No weight changed noticeably
        if (!in_array(true, $changed, true)) {
            return true;
        }

        // Check if the last two cost values are almost the same
        $costs = array_slice($this->costValues, -2);

        return count($costs) === 2 && abs($costs[1] - $costs[0]) < $threshold;
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     */
    protected function clear(): void
    {
        $this->samples = [];
        $this->targets = [];
        $this->gradientCb = null;
    }
}
279