Passed
Push — master ( 55749c...a40c50 )
by
unknown
02:34
created

StochasticGD::setTheta()   A

Complexity

Conditions 2
Paths 2

Size

Total Lines 10
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 10
rs 9.4285
c 0
b 0
f 0
cc 2
eloc 5
nc 2
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
use Closure;
8
use Phpml\Exception\InvalidArgumentException;
9
10
/**
 * Stochastic Gradient Descent optimization method
 * to find a solution for the equation A.ϴ = y where
 * A (samples) and y (targets) are known and ϴ is unknown.
 *
 * The weight vector ϴ holds the bias at index 0 followed by one weight
 * per input dimension (indices 1..dimensions).
 */
class StochasticGD extends Optimizer
{
    /**
     * A (samples); only populated while an optimization run is in progress
     *
     * @var array
     */
    protected $samples = [];

    /**
     * y (targets); only populated while an optimization run is in progress
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Callback function to get the gradient and cost value
     * for a specific set of theta (ϴ) and a pair of sample & target
     *
     * @var \Closure|null
     */
    protected $gradientCb = null;

    /**
     * Maximum number of iterations used to train the model
     *
     * @var int
     */
    protected $maxIterations = 1000;

    /**
     * Learning rate is used to control the speed of the optimization.<br>
     *
     * Larger values of lr may overshoot the optimum or even cause divergence
     * while small values slow down the convergence and increase the time
     * required for the training
     *
     * @var float
     */
    protected $learningRate = 0.001;

    /**
     * Minimum amount of change in the weights and error values
     * between iterations that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-4;

    /**
     * Enable/Disable early stopping by checking the weight & cost values
     * to see whether they changed enough to continue the optimization
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * List of values obtained by evaluating the cost function at each iteration
     * of the algorithm
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initializes the SGD optimizer for the given number of dimensions
     *
     * @param int $dimensions Number of input features (bias excluded)
     */
    public function __construct(int $dimensions)
    {
        // Add one more dimension for the bias
        parent::__construct($dimensions + 1);

        // Keep the feature count itself; the bias slot is addressed as
        // "dimensions + 1" wherever the full weight vector is concerned
        $this->dimensions = $dimensions;
    }

    /**
     * Replaces the current weight vector.
     *
     * @param array $theta Bias followed by one weight per dimension
     *
     * @return $this
     *
     * @throws InvalidArgumentException when the array does not hold exactly dimensions + 1 values
     */
    public function setTheta(array $theta)
    {
        $expected = $this->dimensions + 1;

        if (count($theta) !== $expected) {
            throw new InvalidArgumentException(sprintf('Number of values in the weights array should be %s', $expected));
        }

        $this->theta = $theta;

        return $this;
    }

    /**
     * Sets minimum value for the change in the theta values
     * between iterations to continue the iterations.<br>
     *
     * If change in the theta is less than given value then the
     * algorithm will stop training
     *
     * @return $this
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;

        return $this;
    }

    /**
     * Enable/Disable early stopping by checking at each iteration
     * whether changes in theta or cost value are not large enough
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Sets the step size applied to every gradient update
     *
     * @return $this
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;

        return $this;
    }

    /**
     * Sets the upper bound on the number of passes over the samples
     *
     * @return $this
     */
    public function setMaxIterations(int $maxIterations)
    {
        $this->maxIterations = $maxIterations;

        return $this;
    }

    /**
     * Optimization procedure finds the unknown variables for the equation A.ϴ = y
     * for the given samples (A) and targets (y).<br>
     *
     * The cost function to minimize and the gradient of the function are to be
     * handled by the callback function provided as the third parameter of the method.
     *
     * The callback is invoked as $gradientCb($theta, $sample, $target) and must
     * return an array of [error, gradient] or [error, gradient, penalty].
     *
     * @return array|null The best theta encountered; when no iteration is executed
     *                    (maxIterations <= 0) the current theta is returned unchanged
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): ?array
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;

        $currIter = 0;
        $bestTheta = null;
        $bestScore = 0.0;
        $this->costValues = [];

        while ($this->maxIterations > $currIter++) {
            // Snapshot of the weights the upcoming cost value is evaluated with
            $theta = $this->theta;

            // Update the guess
            $cost = $this->updateTheta();

            // Save the best theta in the "pocket" so that
            // any future set of theta worse than this will be disregarded
            if ($bestTheta === null || $cost <= $bestScore) {
                $bestTheta = $theta;
                $bestScore = $cost;
            }

            // Add the cost value for this iteration to the list
            $this->costValues[] = $cost;

            // Check for early stop
            if ($this->enableEarlyStop && $this->earlyStop($theta)) {
                break;
            }
        }

        $this->clear();

        // Solution in the pocket is better than or equal to the last state
        // so, we use this solution. If the loop never ran there is nothing in
        // the pocket: keep the current weights instead of wiping them out.
        return $this->theta = $bestTheta ?? $this->theta;
    }

    /**
     * Returns the list of cost values for each iteration executed in
     * last run of the optimization
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Runs one pass over all samples, adjusting theta after every sample.
     *
     * The gradient callback is always given the theta captured at the start of
     * the pass, while the updates are applied to the live weights.
     *
     * @return float Mean of the error values reported by the callback,
     *               or 0.0 when there are no samples to process
     */
    protected function updateTheta(): float
    {
        $sampleCount = count($this->samples);

        // Nothing to learn from; also avoids dividing by zero below
        if ($sampleCount === 0) {
            return 0.0;
        }

        $jValue = 0.0;
        $theta = $this->theta;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);

            // The penalty term is optional: missing entries default to 0
            [$error, $gradient, $penalty] = array_pad($result, 3, 0);

            // Update bias
            $this->theta[0] -= $this->learningRate * $gradient;

            // Update other values
            for ($i = 1; $i <= $this->dimensions; ++$i) {
                $this->theta[$i] -= $this->learningRate *
                    ($gradient * $sample[$i - 1] + $penalty * $this->theta[$i]);
            }

            // Sum error rate
            $jValue += $error;
        }

        return $jValue / $sampleCount;
    }

    /**
     * Checks if the optimization is not effective enough and can be stopped
     * in case large enough changes in the solution do not happen
     *
     * @param array $oldTheta Weights as they were before the latest update
     */
    protected function earlyStop(array $oldTheta): bool
    {
        // Check for early stop: no weight changed by more than the configured threshold
        $diff = array_map(
            function ($w1, $w2) {
                return abs($w1 - $w2) > $this->threshold ? 1 : 0;
            },
            $oldTheta,
            $this->theta
        );

        if (array_sum($diff) === 0) {
            return true;
        }

        // Check if the last two cost values are almost the same
        $costs = array_slice($this->costValues, -2);

        return count($costs) === 2 && abs($costs[1] - $costs[0]) < $this->threshold;
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     */
    protected function clear(): void
    {
        $this->samples = [];
        $this->targets = [];
        $this->gradientCb = null;
    }
}
274