Passed
Push to master (fbbe5c...a34811)
by Arkadiusz
created 07:00

Phpml/Classification/Linear/LogisticRegression.php (1 issue)


<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     *  - 'log' : log likelihood <br>
     *  - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of regularization term. If λ is set to 0, then
     * regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initialize a Logistic Regression classifier with maximum number of iterations
     * and learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, then every input given to the algorithm will be standardized
     * by use of standard deviation and mean calculation <br>
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
     *
     * Penalty (Regularization term) can be 'L2' or empty string to cancel penalty term
     *
     * @throws \Exception
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes)) {
            throw new Exception('Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms');
        }

        if (!in_array($cost, ['log', 'sse'])) {
            throw new Exception("Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors");
        }

        if ($penalty != '' && strtoupper($penalty) !== 'L2') {
            throw new Exception("Logistic regression supports only 'L2' regularization");
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
Documentation Bug introduced by
The property $trainingType was declared of type string, but $trainingType is of type integer. Maybe add a type cast?

This check looks for assignments to scalar types that may be of the wrong type.
To ensure the code behaves as expected, it may be a good idea to add an explicit type cast.

$answer = 42;
$correct = false;
$correct = (bool) $answer;
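As a concrete sketch for the flagged line (illustration only; not necessarily how the project resolved it), either cast the assigned value to the declared type or update the property's declaration to match the integer training-type constants:

$this->trainingType = (string) $trainingType; // satisfies the declared string type

or, wherever the property is declared:

/**
 * @var int
 */
protected $trainingType;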
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }

    /**
     * Sets the learning rate if gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of regularization term. If 0 is given,
     * then the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of selected solver
     *
     * @throws \Exception
     */
    protected function runTraining(array $samples, array $targets)
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                return $this->runGradientDescent($samples, $targets, $callback, true);

            case self::ONLINE_TRAINING:
                return $this->runGradientDescent($samples, $targets, $callback, false);

            case self::CONJUGATE_GRAD_TRAINING:
                return $this->runConjugateGradient($samples, $targets, $callback);

            default:
                throw new Exception(sprintf('Logistic regression has invalid training type: %s.', $this->trainingType));
        }
    }

    /**
     * Executes Conjugate Gradient method to optimize the weights of the LogReg model
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if (empty($this->optimizer)) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @throws \Exception
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty == 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of Log-likelihood cost function to be minimized:
                 *      J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *      ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value will give a NaN, so we fix these values
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };

                return $callback;

            case 'sse':
                /*
                 * Sum of squared errors or least squared errors cost function:
                 *      J(x) = ∑ (y - h(x))^2
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function:
                 *      ∇J(x) = -(h(x) - y) . h(x) . (1 - h(x))
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    $y = $y < 0 ? 0 : 1;

                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };

                return $callback;

            default:
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }

    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is taken as the sigmoid output of the model
     * for the given sample.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
}
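For reference, a minimal usage sketch of this classifier through php-ml's usual train()/predict() classifier API (the sample data below is made up for illustration):

use Phpml\Classification\Linear\LogisticRegression;

// Tiny made-up dataset: class 0 clustered near the origin, class 1 further out
$samples = [[0.1, 0.2], [0.2, 0.1], [0.9, 0.8], [1.0, 0.9]];
$labels  = [0, 0, 1, 1];

$classifier = new LogisticRegression(
    500,                                         // maximum iterations
    true,                                        // normalize inputs
    LogisticRegression::CONJUGATE_GRAD_TRAINING, // training type
    'log',                                       // cost function
    'L2'                                         // penalty
);
$classifier->train($samples, $labels);

echo $classifier->predict([0.95, 0.85]); // expected to print the second label (1)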