Passed · Push to master ( fbbe5c...a34811 ) by Arkadiusz · created 07:00

LogisticRegression::runTraining()   B

Complexity

Conditions 4
Paths 4

Size

Total Lines 24
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0

Metric   Value
dl       0
loc      24
rs       8.6845
c        0
b        0
f        0
cc       4
eloc     14
nc       4
nop      2
<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     *  - 'log' : log likelihood <br>
     *  - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of the regularization term. If λ is set to 0,
     * the regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initialize a Logistic Regression classifier with the maximum number of iterations
     * and the learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, then every input given to the algorithm will be standardized
     * by use of standard deviation and mean calculation <br>
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
     *
     * Penalty (regularization term) can be 'L2' or an empty string to cancel the penalty term
     *
     * @throws \Exception
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes, true)) {
            throw new Exception('Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms');
        }

        if (!in_array($cost, ['log', 'sse'], true)) {
            throw new Exception("Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors");
        }

        if ($penalty !== '' && strtoupper($penalty) !== 'L2') {
            throw new Exception("Logistic regression supports only 'L2' regularization");
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }

    /**
     * Sets the learning rate if the gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Sets the lambda (λ) parameter of the regularization term. If 0 is given,
     * the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }

    /**
     * Adapts the weights with respect to the given samples and targets
     * by use of the selected solver
     *
     * @throws \Exception
     */
    protected function runTraining(array $samples, array $targets): void
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, true);

                return;

            case self::ONLINE_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, false);

                return;

            case self::CONJUGATE_GRAD_TRAINING:
                $this->runConjugateGradient($samples, $targets, $callback);

                return;

            default:
                throw new Exception(sprintf('Logistic regression has invalid training type: %s.', $this->trainingType));
        }
    }

    /**
     * Executes the Conjugate Gradient method to optimize the weights of the LogReg model
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if ($this->optimizer === null) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @throws \Exception
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty === 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of the log-likelihood cost function to be minimized:
                 *      J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If a regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *      ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // When $hX is exactly 1 or 0, the log() calls below would
                    // blow up (log(0) is -INF), so clamp the value into (0, 1)
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };

                return $callback;

            case 'sse':
                /*
                 * Sum of squared errors (least squares) cost function:
                 *      J(x) = ∑ (y - h(x))^2
                 *
                 * If a regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function:
                 *      ∇J(x) = -(y - h(x)) . h(x) . (1 - h(x))
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    $y = $y < 0 ? 0 : 1;

                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };

                return $callback;

            default:
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }

    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is the logistic (sigmoid) output computed from the
     * sample's distance to the decision plane.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
}
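For context, a minimal usage sketch follows. It is an assumption, not part of the file above: train() and predict() are inherited from the Phpml base classifier hierarchy (Adaline/Perceptron) rather than defined here, and the Composer autoloader path and the toy data are purely illustrative. The constructor arguments mirror the signature shown in the listing.

<?php

declare(strict_types=1);

require_once 'vendor/autoload.php'; // assumed Composer project setup

use Phpml\Classification\Linear\LogisticRegression;

// Toy, linearly separable data: two clusters of 1-D samples labelled 0 and 1.
$samples = [[0.1], [0.2], [0.3], [2.9], [3.0], [3.1]];
$targets = [0, 0, 0, 1, 1, 1];

// Max iterations, input normalization, training algorithm, cost function, penalty.
$classifier = new LogisticRegression(
    500,
    true,
    LogisticRegression::CONJUGATE_GRAD_TRAINING,
    'log',
    'L2'
);

$classifier->train($samples, $targets);

// A sample near the second cluster should be assigned that cluster's label (1 here).
var_dump($classifier->predict([3.2]));

Before calling train(), setLearningRate() and setLambda() from the class above can be used to tune the step size (for the gradient descent solvers) and the strength of the L2 regularization term.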