Passed
Push — master (fbbe5c...a34811)
by Arkadiusz, created at 07:00

Phpml/Classification/Linear/LogisticRegression.php (1 issue)

Checks property assignments for possibly missing type casts
Labels: Bug · Documentation · Minor


<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     *  - 'log' : log likelihood <br>
     *  - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of regularization term. If λ is set to 0, then
     * regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initialize a Logistic Regression classifier with maximum number of iterations
     * and learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, then every input given to the algorithm will be standardized
     * by use of standard deviation and mean calculation <br>
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
     *
     * Penalty (Regularization term) can be 'L2' or empty string to cancel penalty term
     *
     * @throws \Exception
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes)) {
            throw new Exception('Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms');
        }

        if (!in_array($cost, ['log', 'sse'])) {
            throw new Exception("Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors");
        }

        if ($penalty != '' && strtoupper($penalty) !== 'L2') {
            throw new Exception("Logistic regression supports only 'L2' regularization");
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;

[Documentation Bug introduced by the assignment above]
The property $trainingType was declared of type string, but $trainingType is
of type integer. Maybe add a type cast?

This check looks for assignments to scalar types that may be of the wrong
type. To ensure the code behaves as expected, it may be a good idea to add an
explicit type cast:

    $answer = 42;
    $correct = false;
    $correct = (bool) $answer;

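A minimal remediation sketch, assuming the inherited @var string annotation on
$trainingType is kept (the alternative fix would be changing that annotation
to @var int, since the class only ever stores the integer training-type
constants):

    // Hypothetical fix: cast so the value matches the declared @var string.
    $this->trainingType = (string) $trainingType;
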
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }

    /**
     * Sets the learning rate if gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of regularization term. If 0 is given,
     * then the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }
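
For reference, a short usage sketch of the constructor and setters above; the
class and constants come from this file, while the concrete values are purely
illustrative:

    use Phpml\Classification\Linear\LogisticRegression;

    // Online (stochastic gradient descent) training with log-likelihood
    // cost and L2 regularization.
    $classifier = new LogisticRegression(500, true, LogisticRegression::ONLINE_TRAINING, 'log', 'L2');

    $classifier->setLearningRate(0.01); // only used by the gradient descent solvers
    $classifier->setLambda(0.5);        // λ = 0 would cancel the L2 penalty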

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of selected solver
     *
     * @throws \Exception
     */
    protected function runTraining(array $samples, array $targets)
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                return $this->runGradientDescent($samples, $targets, $callback, true);

            case self::ONLINE_TRAINING:
                return $this->runGradientDescent($samples, $targets, $callback, false);

            case self::CONJUGATE_GRAD_TRAINING:
                return $this->runConjugateGradient($samples, $targets, $callback);

            default:
                throw new Exception(sprintf('Logistic regression has invalid training type: %s.', $this->trainingType));
        }
    }
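
Assuming the public train()/predict() API that php-ml classifiers inherit via
the Adaline/Perceptron hierarchy, a call like the following ends up in
runTraining() above; the data is made up for illustration:

    $samples = [[0.1, 0.3], [0.9, 0.8], [0.2, 0.1], [0.7, 0.9]];
    $targets = [-1, 1, -1, 1];

    $classifier->train($samples, $targets); // dispatches to the selected solver
    $prediction = $classifier->predict([0.8, 0.7]);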

    /**
     * Executes Conjugate Gradient method to optimize the weights of the LogReg model
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if (empty($this->optimizer)) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @throws \Exception
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty == 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of Log-likelihood cost function to be minimized:
                 *		J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *		for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *		∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value will give a NaN, so we fix these values
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };

                return $callback;

            case 'sse':
                /*
                 * Sum of squared errors or least squared errors cost function:
                 *		J(x) = ∑ (y - h(x))^2
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *		for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function:
                 *		∇J(x) = -(h(x) - y) . h(x) . (1 - h(x))
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    $y = $y < 0 ? 0 : 1;

                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };

                return $callback;

            default:
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }
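
To make the 'log' branch concrete, one callback evaluation worked by hand
(standalone PHP, independent of the class; the numbers are arbitrary):

    // Suppose the current model output is h(x) = 0.8 for a sample whose
    // raw target is 1 (targets below 0 are remapped to 0, others to 1).
    $hX = 0.8;
    $y  = 1;

    $error    = -$y * log($hX) - (1 - $y) * log(1 - $hX); // ≈ 0.223
    $gradient = $hX - $y;                                 // -0.2

    // With the same h(x) but $y = 0 the cost is -log(0.2) ≈ 1.609:
    // confidently wrong predictions are penalized much more heavily.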

    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }
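
The two methods above form a standard sigmoid-plus-threshold decision rule; a
quick standalone illustration (the sums are arbitrary):

    $sigmoid = function (float $sum): float {
        return 1.0 / (1.0 + exp(-$sum));
    };

    $pHigh = $sigmoid(2.0);  // ≈ 0.88 → above 0.5, outputClass() returns  1
    $pLow  = $sigmoid(-2.0); // ≈ 0.12 → below 0.5, outputClass() returns -1
    $pMid  = $sigmoid(0.0);  //   0.50 → not greater than 0.5, so class -1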

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is taken as the sigmoid of the sample's distance
     * to the decision plane.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
}
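
Since the model is binary, the second label's probability is simply the
complement of the sigmoid output. A small sketch mirroring the branch above,
with a hypothetical two-label set stored in training order:

    $labels = ['negative', 'positive']; // hypothetical $this->labels
    $probability = 0.7;                 // sigmoid output for some sample

    $label = 'positive';
    $p = array_search($label, $labels, true) > 0
        ? $probability       // 0.7 for the label at index 1
        : 1 - $probability;  // 0.3 for the label at index 0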