Passed
Push — master ( fbbe5c...a34811 )
by Arkadiusz
07:00
created

src/Phpml/Helper/Optimizer/StochasticGD.php (1 issue)

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead.

<?php

declare(strict_types=1);

namespace Phpml\Helper\Optimizer;

use Closure;

/**
 * Stochastic Gradient Descent optimization method
 * to find a solution for the equation A.ϴ = y where
 * A (samples) and y (targets) are known and ϴ is unknown.
 */
class StochasticGD extends Optimizer
{
    /**
     * A (samples)
     *
     * @var array
     */
    protected $samples = [];

    /**
     * y (targets)
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Callback function to get the gradient and cost value
     * for a specific set of theta (ϴ) and a pair of sample & target
     *
     * @var \Closure|null
     */
    protected $gradientCb = null;

    /**
     * Maximum number of iterations used to train the model
     *
     * @var int
     */
    protected $maxIterations = 1000;

    /**
     * Learning rate is used to control the speed of the optimization.<br>
     *
     * Larger values of lr may overshoot the optimum or even cause divergence,
     * while small values slow down the convergence and increase the time
     * required for the training.
     *
     * @var float
     */
    protected $learningRate = 0.001;

    /**
     * Minimum amount of change in the weights and error values
     * between iterations that is required to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-4;

    /**
     * Enable/Disable early stopping by checking whether the weight & cost values
     * changed enough between iterations to continue the optimization
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * List of values obtained by evaluating the cost function at each iteration
     * of the algorithm
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initializes the SGD optimizer for the given number of dimensions
     */
    public function __construct(int $dimensions)
    {
        // Add one more dimension for the bias
        parent::__construct($dimensions + 1);

        // Keep the number of features (bias excluded) for the loop in updateTheta()
        $this->dimensions = $dimensions;
    }

    /**
     * Sets the minimum change in the theta values (and in the cost)
     * between iterations that is required to keep iterating.<br>
     *
     * When early stopping is enabled, training stops as soon as no theta
     * value changes by more than the given value, or the cost changes
     * by less than it.
     *
     * @return $this
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;

        return $this;
    }

    /**
     * Enable/Disable early stopping by checking at each iteration
     * whether the changes in theta or in the cost value are too small to continue
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Sets the learning rate that scales every gradient step
     *
     * @return $this
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;

        return $this;
    }

    /**
     * Sets the maximum number of passes over the training samples
     *
     * @return $this
     */
    public function setMaxIterations(int $maxIterations)
    {
        $this->maxIterations = $maxIterations;

        return $this;
    }

    /**
     * Optimization procedure that finds the unknown variables of the equation A.ϴ = y
     * for the given samples (A) and targets (y).<br>
     *
     * The cost function to minimize and its gradient are computed by the callback
     * passed as the third parameter: it receives theta, a sample and its target,
     * and returns [error, gradient] or [error, gradient, penalty].
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): array
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;

        $currIter = 0;
        $bestTheta = null;
        $bestScore = 0.0;
        $this->costValues = [];

        while ($this->maxIterations > $currIter++) {
            $theta = $this->theta;

            // Update the guess
            $cost = $this->updateTheta();

            // Save the best theta in the "pocket" so that
            // any future set of theta worse than this will be disregarded
            if ($bestTheta == null || $cost <= $bestScore) {
                $bestTheta = $theta;
                $bestScore = $cost;
            }

            // Add the cost value for this iteration to the list
            $this->costValues[] = $cost;

            // Check for early stop
            if ($this->enableEarlyStop && $this->earlyStop($theta)) {
                break;
            }
        }

        $this->clear();

        // The solution in the pocket is better than or equal to the last state,
        // so we use this solution
        return $this->theta = $bestTheta;
Documentation Bug introduced by Mustafa Karabulut
It seems like $bestTheta can be null. However, the property $theta is declared as array. Maybe change the type of the property to array|null or add a type check?

Our type inference engine has found an assignment of a scalar value (like a string, an integer or null) to a property which is an array.

Either this assignment is in error or the assigned type should be added to the documentation/type hint for that property.

To type hint that a parameter can be either an array or null, declare it with the nullable type ?array (available since PHP 7.1). Older code often writes a type hint of array with a default value of null instead; PHP treats that as nullable too, but this implicit form is deprecated as of PHP 8.4. In both cases the PHP interpreter accepts either an array or null for that parameter.

function aContainsB(?array $needle, array $haystack) {
    if (!$needle) {
        return false;
    }

    return array_intersect($haystack, $needle) == $haystack;
}

The function can be called with either null or an array for the parameter $needle but will only accept an array as $haystack.
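One way to resolve the reported issue is sketched below, assuming that keeping the current theta is the desired result when the loop body never runs (for example after setMaxIterations(0)): fall back with the null coalescing operator instead of returning a possibly-null pocket. pocketOptimize() and the $step callback are hypothetical stand-ins for runOptimization() and updateTheta(); they are not part of the Phpml namespace.

```php
<?php

/**
 * Hypothetical, stripped-down version of the "pocket" loop in
 * StochasticGD::runOptimization(). $step receives the current theta and
 * returns [next theta, cost of the theta it received].
 */
function pocketOptimize(array $theta, int $maxIterations, Closure $step): array
{
    $bestTheta = null;
    $bestScore = 0.0;

    for ($iter = 0; $iter < $maxIterations; ++$iter) {
        [$nextTheta, $cost] = $step($theta);

        // Keep the lowest-cost theta seen so far in the "pocket"
        if ($bestTheta === null || $cost <= $bestScore) {
            $bestTheta = $theta;
            $bestScore = $cost;
        }

        $theta = $nextTheta;
    }

    // With zero iterations the pocket is still null, so fall back to the
    // current theta and keep the array return type honest
    return $bestTheta ?? $theta;
}

// Hypothetical step: cost = sum of squares, next theta = theta / 2
$halve = function (array $theta): array {
    $cost = 0.0;
    foreach ($theta as $t) {
        $cost += $t * $t;
    }

    return [array_map(function ($t) { return $t / 2; }, $theta), $cost];
};

var_dump(pocketOptimize([1.0, 2.0], 0, $halve)); // the starting theta is returned unchanged
```

Widening the property to array|null, the report's other suggestion, would also silence the warning, but every reader of $theta would then need its own null check.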

    }

    /**
     * Returns the list of cost values for each iteration executed in
     * the last run of the optimization
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Runs one pass over all samples, updating theta after each sample,
     * and returns the mean of the error values reported by the callback
     */
    protected function updateTheta(): float
    {
        $jValue = 0.0;
        $theta = $this->theta;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            // The callback is evaluated on the theta snapshot taken before this pass
            $result = ($this->gradientCb)($theta, $sample, $target);

            // A missing penalty (or gradient) defaults to 0
            [$error, $gradient, $penalty] = array_pad($result, 3, 0);

            // Update bias (not penalized)
            $this->theta[0] -= $this->learningRate * $gradient;

            // Update the feature weights; $penalty * theta acts as an L2 regularization term
            for ($i = 1; $i <= $this->dimensions; ++$i) {
                $this->theta[$i] -= $this->learningRate *
                    ($gradient * $sample[$i - 1] + $penalty * $this->theta[$i]);
            }

            // Accumulate the error of this sample
            $jValue += $error;
        }

        return $jValue / count($this->samples);
    }

    /**
     * Checks whether the optimization has stalled and can be stopped, i.e.
     * whether the last iteration no longer produced large enough changes
     * in theta or in the cost value
     */
    protected function earlyStop(array $oldTheta): bool
    {
        // Check for early stop: no theta value changed by more than the
        // threshold (1e-4 unless changed with setChangeThreshold())
        $diff = array_map(
            function ($w1, $w2) {
                return abs($w1 - $w2) > $this->threshold ? 1 : 0;
            },
            $oldTheta,
            $this->theta
        );

        if (array_sum($diff) == 0) {
            return true;
        }

        // Check if the last two cost values are almost the same
        $costs = array_slice($this->costValues, -2);
        if (count($costs) == 2 && abs($costs[1] - $costs[0]) < $this->threshold) {
            return true;
        }

        return false;
    }

    /**
     * Clears the optimizer's internal state (samples, targets, callback)
     * after the optimization process.
     */
    protected function clear(): void
    {
        $this->samples = [];
        $this->targets = [];
        $this->gradientCb = null;
    }
}
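The callback contract and the per-sample update in updateTheta() are easier to follow with a concrete run. The sketch below re-implements one updateTheta()-style pass as a standalone function and uses it to fit a small least-squares line. sgdPass() and $leastSquares are hypothetical names, not part of the Phpml namespace, and the cost 0.5 * residual^2 is an assumption chosen so that the callback's gradient is simply the prediction minus the target.

```php
<?php

/**
 * Hypothetical helper that mirrors one StochasticGD::updateTheta() pass:
 * $theta[0] is the bias and $theta[1..n] are the feature weights.
 * Returns [updated theta, mean error of the theta the pass started from].
 */
function sgdPass(array $theta, array $samples, array $targets, float $learningRate, Closure $gradientCb): array
{
    $snapshot = $theta;           // the callback always sees the start-of-pass theta
    $dimensions = count($theta) - 1;
    $jValue = 0.0;

    foreach ($samples as $index => $sample) {
        $result = $gradientCb($snapshot, $sample, $targets[$index]);
        [$error, $gradient, $penalty] = array_pad($result, 3, 0);

        // Bias update (no penalty term), then the feature weights
        $theta[0] -= $learningRate * $gradient;
        for ($i = 1; $i <= $dimensions; ++$i) {
            $theta[$i] -= $learningRate * ($gradient * $sample[$i - 1] + $penalty * $theta[$i]);
        }

        $jValue += $error;
    }

    return [$theta, $jValue / count($samples)];
}

// Hypothetical least-squares callback for a linear model:
// cost = 0.5 * residual^2, so d(cost)/d(prediction) = residual
$leastSquares = function (array $theta, array $sample, float $target): array {
    $prediction = $theta[0];
    foreach ($sample as $i => $x) {
        $prediction += $theta[$i + 1] * $x;
    }
    $residual = $prediction - $target;

    return [0.5 * $residual ** 2, $residual, 0.0]; // [error, gradient, penalty]
};

// Fit y = 1 + 2x; each iteration is one pass over all samples
$samples = [[0.0], [1.0], [2.0], [3.0]];
$targets = [1.0, 3.0, 5.0, 7.0];
$theta = [0.0, 0.0];
$costs = [];

for ($iter = 0; $iter < 2000; ++$iter) {
    [$theta, $cost] = sgdPass($theta, $samples, $targets, 0.05, $leastSquares);
    $costs[] = $cost;
}

printf("bias=%.4f slope=%.4f\n", $theta[0], $theta[1]); // prints values close to bias=1 and slope=2
```

Because the callback sees the start-of-pass snapshot, the updates within one pass add up to a single batch gradient step (the penalty term aside); on this tiny data set a learning rate of 0.05 is small enough for that step to converge. Early stopping is left out to keep the sketch short.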