Passed
Push — master ( c44f3b...492344 )
by Arkadiusz
02:48
created

StochasticGD::setMaxIterations()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 6
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 6
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 3
nc 1
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Helper\Optimizer;
6
7
/**
 * Stochastic Gradient Descent optimization method
 * to find a solution for the equation A.ϴ = y where
 *  A (samples) and y (targets) are known and ϴ is unknown.
 */
class StochasticGD extends Optimizer
{
    /**
     * A (samples)
     *
     * @var array
     */
    protected $samples;

    /**
     * y (targets)
     *
     * @var array
     */
    protected $targets;

    /**
     * Callback function to get the gradient and cost value
     * for a specific set of theta (ϴ) and a pair of sample & target.
     *
     * It is invoked as ($theta, $sample, $target) and must return an array
     * of up to three items: [error, gradient, penalty]. Missing items are
     * treated as 0.
     *
     * @var \Closure
     */
    protected $gradientCb;

    /**
     * Maximum number of iterations used to train the model
     *
     * @var int
     */
    protected $maxIterations = 1000;

    /**
     * Learning rate is used to control the speed of the optimization.
     *
     * Larger values of lr may overshoot the optimum or even cause divergence
     * while small values slow down the convergence and increase the time
     * required for the training
     *
     * @var float
     */
    protected $learningRate = 0.001;

    /**
     * Minimum amount of change in the weights and error values
     * between iterations that needs to be obtained to continue the training
     *
     * @var float
     */
    protected $threshold = 1e-4;

    /**
     * Enable/Disable early stopping by checking the weight & cost values
     * to see whether they changed large enough to continue the optimization
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * List of values obtained by evaluating the cost function at each iteration
     * of the algorithm
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initializes the SGD optimizer for the given number of dimensions
     *
     * @param int $dimensions Number of features (the bias term is not counted)
     */
    public function __construct(int $dimensions)
    {
        // Add one more dimension for the bias
        parent::__construct($dimensions + 1);

        // Keep the feature count itself; index 0 of theta is used for the bias
        $this->dimensions = $dimensions;
    }

    /**
     * Sets minimum value for the change in the theta values
     * between iterations to continue the iterations.
     *
     * If change in the theta is less than given value then the
     * algorithm will stop training
     *
     * @param float $threshold
     *
     * @return $this
     */
    public function setChangeThreshold(float $threshold = 1e-5)
    {
        $this->threshold = $threshold;

        return $this;
    }

    /**
     * Enable/Disable early stopping by checking at each iteration
     * whether changes in theta or cost value are not large enough
     *
     * @param bool $enable
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Sets the step size applied to every gradient update.
     *
     * @param float $learningRate
     *
     * @return $this
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;

        return $this;
    }

    /**
     * Sets the upper bound on the number of passes over the samples.
     *
     * @param int $maxIterations
     *
     * @return $this
     */
    public function setMaxIterations(int $maxIterations)
    {
        $this->maxIterations = $maxIterations;

        return $this;
    }

    /**
     * Optimization procedure finds the unknown variables for the equation A.ϴ = y
     * for the given samples (A) and targets (y).
     *
     * The cost function to minimize and the gradient of the function are to be
     * handled by the callback function provided as the third parameter of the method.
     *
     * @param array    $samples
     * @param array    $targets
     * @param \Closure $gradientCb Called as ($theta, $sample, $target); returns [error, gradient, penalty]
     *
     * @return array The best theta found; the current theta when no iteration was executed
     */
    public function runOptimization(array $samples, array $targets, \Closure $gradientCb)
    {
        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;

        $currIter = 0;
        $bestTheta = null;
        $bestScore = 0.0;
        $this->costValues = [];

        while ($this->maxIterations > $currIter++) {
            // Snapshot of theta before the update; the cost returned below
            // is evaluated against this snapshot
            $theta = $this->theta;

            // Update the guess
            $cost = $this->updateTheta();

            // Save the best theta in the "pocket" so that
            // any future set of theta worse than this will be disregarded.
            // Strict comparison: an empty array must not be mistaken for "not set yet"
            if ($bestTheta === null || $cost <= $bestScore) {
                $bestTheta = $theta;
                $bestScore = $cost;
            }

            // Add the cost value for this iteration to the list
            $this->costValues[] = $cost;

            // Check for early stop
            if ($this->enableEarlyStop && $this->earlyStop($theta)) {
                break;
            }
        }

        // No iteration was executed (maxIterations <= 0): there is nothing in
        // the pocket, so keep the current theta instead of overwriting it with null
        if ($bestTheta === null) {
            return $this->theta;
        }

        // Solution in the pocket is better than or equal to the last state
        // so, we use this solution
        return $this->theta = $bestTheta;
    }

    /**
     * Performs one pass over all samples, updating theta after each sample.
     *
     * The callback is always evaluated with the theta captured at the start
     * of the pass, while the updates are applied to the live theta.
     *
     * @return float Mean of the error values returned by the callback (0.0 when there are no samples)
     */
    protected function updateTheta()
    {
        $sampleCount = count($this->samples);

        // Nothing to learn from; also avoids a division by zero below
        if ($sampleCount === 0) {
            return 0.0;
        }

        $jValue = 0.0;
        $theta = $this->theta;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);

            // The callback may omit the penalty (or gradient); default them to 0
            list($error, $gradient, $penalty) = array_pad($result, 3, 0);

            // Update bias
            $this->theta[0] -= $this->learningRate * $gradient;

            // Update other values; theta[$i] pairs with feature $i - 1
            for ($i = 1; $i <= $this->dimensions; $i++) {
                $this->theta[$i] -= $this->learningRate *
                    ($gradient * $sample[$i - 1] + $penalty * $this->theta[$i]);
            }

            // Sum error rate
            $jValue += $error;
        }

        return $jValue / $sampleCount;
    }

    /**
     * Checks if the optimization is not effective enough and can be stopped
     * in case large enough changes in the solution do not happen
     *
     * @param array $oldTheta Theta values as they were before the latest update
     *
     * @return bool
     */
    protected function earlyStop($oldTheta)
    {
        // Check for early stop: no weight changed by more than the configured threshold
        $diff = array_map(
            function ($w1, $w2) {
                return abs($w1 - $w2) > $this->threshold ? 1 : 0;
            },
            $oldTheta,
            $this->theta
        );

        if (array_sum($diff) == 0) {
            return true;
        }

        // Check if the last two cost values are almost the same
        $costs = array_slice($this->costValues, -2);
        if (count($costs) == 2 && abs($costs[1] - $costs[0]) < $this->threshold) {
            return true;
        }

        return false;
    }

    /**
     * Returns the list of cost values for each iteration executed in
     * last run of the optimization
     *
     * @return array
     */
    public function getCostValues()
    {
        return $this->costValues;
    }
}
272