Passed
Push — master ( 331d4b...653c7c )
by Arkadiusz
02:19
created

src/Phpml/Helper/Optimizer/StochasticGD.php (3 issues)

Labels
Severity

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
/**
 * Stochastic Gradient Descent optimization method
 * to find a solution for the equation A.ϴ = y where
 * A (samples) and y (targets) are known and ϴ is unknown.
 *
 * The weight vector ($this->theta) and the dimension count ($this->dimensions)
 * are not declared in this class; they are presumably inherited from Optimizer.
 */
class StochasticGD extends Optimizer
{
    /**
     * A (samples)
     *
     * Only populated while runOptimization() is executing; reset by clear().
     *
     * @var array
     */
    protected $samples = [];

    /**
     * y (targets)
     *
     * Looked up with the same keys as $samples.
     * Only populated while runOptimization() is executing; reset by clear().
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Callback function to get the gradient and cost value
     * for a specific set of theta (ϴ) and a pair of sample & target.
     *
     * It is invoked as ($theta, $sample, $target) and its result is read as
     * [error, gradient, penalty]; missing trailing entries are treated as 0.
     * Null outside of an optimization run.
     *
     * @var \Closure|null
     */
    protected $gradientCb = null;

    /**
     * Maximum number of iterations used to train the model
     *
     * @var int
     */
    protected $maxIterations = 1000;

    /**
     * Learning rate is used to control the speed of the optimization.
     *
     * Larger values of lr may overshoot the optimum or even cause divergence,
     * while small values slow down the convergence and increase the time
     * required for the training.
     *
     * @var float
     */
    protected $learningRate = 0.001;

    /**
     * Minimum amount of change in the weights and error values
     * between iterations that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-4;

    /**
     * Enable/Disable early stopping by checking the weight & cost values
     * to see whether they changed enough to continue the optimization
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * List of values obtained by evaluating the cost function at each iteration
     * of the algorithm
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initializes the SGD optimizer for the given number of dimensions
     *
     * @param int $dimensions Number of features, excluding the bias term
     */
    public function __construct(int $dimensions)
    {
        // Add one more dimension for the bias
        parent::__construct($dimensions + 1);

        // Store the feature count without the bias; updateTheta() treats
        // index 0 as the bias and indices 1..$dimensions as feature weights.
        $this->dimensions = $dimensions;
    }

    /**
     * Sets minimum value for the change in the theta values
     * between iterations to continue the iterations.
     *
     * If change in the theta is less than given value then the
     * algorithm will stop training.
     *
     * Note: the property starts at 1e-4; calling this method without an
     * argument sets it to 1e-5.
     *
     * @return $this
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;

        return $this;
    }

    /**
     * Enable/Disable early stopping by checking at each iteration
     * whether changes in theta or cost value are not large enough
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Sets the step size applied to every weight update.
     *
     * @return $this
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;

        return $this;
    }

    /**
     * Sets the upper bound on the number of passes over the samples.
     *
     * @return $this
     */
    public function setMaxIterations(int $maxIterations)
    {
        $this->maxIterations = $maxIterations;

        return $this;
    }

    /**
     * Optimization procedure finds the unknown variables for the equation A.ϴ = y
     * for the given samples (A) and targets (y).
     *
     * The cost function to minimize and the gradient of the function are to be
     * handled by the callback function provided as the third parameter of the method.
     *
     * @param array    $samples    Samples (A); each sample is indexed 0..dimensions-1
     * @param array    $targets    Targets (y), keyed like $samples
     * @param \Closure $gradientCb Called as ($theta, $sample, $target); returns
     *                             [error, gradient] or [error, gradient, penalty]
     *
     * @return array The theta with the lowest recorded cost; if no iteration
     *               was executed, the current theta is returned unchanged
     */
    public function runOptimization(array $samples, array $targets, \Closure $gradientCb): array
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;

        $currIter = 0;
        $bestTheta = null;
        $bestScore = 0.0;
        $this->costValues = [];

        try {
            while ($this->maxIterations > $currIter++) {
                // Snapshot of the weights before this pass. updateTheta() hands
                // the same pre-update weights to the callback, so the returned
                // cost belongs to this snapshot.
                $theta = $this->theta;

                // Update the guess
                $cost = $this->updateTheta();

                // Save the best theta in the "pocket" so that
                // any future set of theta worse than this will be disregarded
                if ($bestTheta === null || $cost <= $bestScore) {
                    $bestTheta = $theta;
                    $bestScore = $cost;
                }

                // Add the cost value for this iteration to the list
                $this->costValues[] = $cost;

                // Check for early stop
                if ($this->enableEarlyStop && $this->earlyStop($theta)) {
                    break;
                }
            }
        } finally {
            // Release the references to the data and the callback even when
            // the callback throws.
            $this->clear();
        }

        // Solution in the pocket is better than or equal to the last state
        // so, we use this solution. When the loop never ran (maxIterations <= 0)
        // the pocket is empty, so keep the current theta instead of writing null.
        return $this->theta = $bestTheta ?? $this->theta;
    }

    /**
     * Runs one pass over all samples, adjusting $this->theta after each sample.
     *
     * @return float Mean of the error values reported by the callback during
     *               the pass; 0.0 when there are no samples
     */
    protected function updateTheta(): float
    {
        $sampleCount = count($this->samples);

        // Nothing to learn from: leave theta untouched and avoid dividing by zero.
        if ($sampleCount === 0) {
            return 0.0;
        }

        $jValue = 0.0;

        // The callback receives this start-of-pass copy for every sample, not
        // the progressively updated $this->theta.
        // NOTE(review): confirm this is intended.
        $theta = $this->theta;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);

            // Short list destructuring (PHP 7.1+). array_pad() guarantees three
            // entries, so $penalty is 0 when the callback omits it.
            [$error, $gradient, $penalty] = array_pad($result, 3, 0);

            // Update bias
            $this->theta[0] -= $this->learningRate * $gradient;

            // Update other values; the penalty term uses the current weight
            for ($i = 1; $i <= $this->dimensions; ++$i) {
                $this->theta[$i] -= $this->learningRate *
                    ($gradient * $sample[$i - 1] + $penalty * $this->theta[$i]);
            }

            // Sum error rate
            $jValue += $error;
        }

        return $jValue / $sampleCount;
    }

    /**
     * Checks if the optimization is not effective enough and can be stopped
     * in case large enough changes in the solution do not happen
     *
     * @param array $oldTheta Weights as they were before the latest update
     *
     * @return bool True when no weight moved by more than the threshold, or
     *              when the last two cost values differ by less than the threshold
     */
    protected function earlyStop(array $oldTheta): bool
    {
        // Flag every weight whose change exceeds the configured threshold
        $diff = array_map(
            function ($w1, $w2) {
                return abs($w1 - $w2) > $this->threshold ? 1 : 0;
            },
            $oldTheta,
            $this->theta
        );

        // No weight changed enough
        if (array_sum($diff) === 0) {
            return true;
        }

        // Check if the last two cost values are almost the same
        $costs = array_slice($this->costValues, -2);
        if (count($costs) === 2 && abs($costs[1] - $costs[0]) < $this->threshold) {
            return true;
        }

        return false;
    }

    /**
     * Returns the list of cost values for each iteration executed in
     * last run of the optimization
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     *
     * The recorded cost values are kept so getCostValues() still works.
     */
    protected function clear(): void
    {
        $this->samples = [];
        $this->targets = [];
        $this->gradientCb = null;
    }
}
259