Issues (14)

src/Classification/Linear/LogisticRegression.php (1 issue)

<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm (default)
     */
    public const CONJUGATE_GRAD_TRAINING = 3;
    /**
     * Cost function to optimize: 'log' and 'sse' are supported
     *  - 'log' : log likelihood
     *  - 'sse' : sum of squared errors
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of the regularization term. If λ is set to 0,
     * the regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;
    /**
     * Initialize a Logistic Regression classifier with the maximum number of
     * iterations and the learning rule to be applied.
     *
     * Maximum number of iterations can be an integer value greater than 0.
     * If normalizeInputs is set to true, then every input given to the
     * algorithm will be standardized by use of standard deviation and mean
     * calculation.
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of
     * squared errors.
     *
     * Penalty (regularization term) can be 'L2', or an empty string to cancel
     * the penalty term.
     *
     * @throws InvalidArgumentException
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes, true)) {
            throw new InvalidArgumentException(
                'Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms'
            );
        }

        if (!in_array($cost, ['log', 'sse'], true)) {
            throw new InvalidArgumentException(
                "Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors"
            );
        }

        if ($penalty !== '' && strtoupper($penalty) !== 'L2') {
            throw new InvalidArgumentException('Logistic regression supports only \'L2\' regularization');
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }
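
Note: a minimal usage sketch of this constructor, assuming the train()/predict() API that Phpml classifiers inherit; the sample data below is made up for illustration:

    use Phpml\Classification\Linear\LogisticRegression;

    // Defaults: 500 iterations, input standardization, conjugate gradient
    // training, log-likelihood cost, L2 penalty
    $classifier = new LogisticRegression();
    $classifier->train([[0.1, 0.2], [0.9, 0.8]], [-1, 1]);
    $prediction = $classifier->predict([0.8, 0.7]); // a single label, -1 or 1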
    /**
     * Sets the learning rate if a gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of the regularization term. If 0 is given,
     * the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }
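
Note: a tuning sketch for these setters; the values are illustrative, not recommendations:

    $classifier = new LogisticRegression(1000, true, LogisticRegression::BATCH_TRAINING);
    $classifier->setLearningRate(0.01); // only used by the gradient descent trainers
    $classifier->setLambda(0.0);        // 0 disables the L2 penalty entirely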
    /**
     * Adapts the weights with respect to given samples and targets
     * by use of the selected solver
     *
     * @throws \Exception
     */
    protected function runTraining(array $samples, array $targets): void
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, true);

                return;

            case self::ONLINE_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, false);

                return;

            case self::CONJUGATE_GRAD_TRAINING:
                $this->runConjugateGradient($samples, $targets, $callback);

                return;

            default:
                // Not reached
                throw new Exception(sprintf('Logistic regression has invalid training type: %d.', $this->trainingType));
        }
    }
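
Note: the three training types trade cost per step against convergence behavior. Batch gradient descent updates the weights once per pass over the whole training set, online (stochastic) gradient descent updates after every sample, and conjugate gradient typically converges in fewer iterations. A sketch of selecting each mode, using the constructor defined above:

    $batch     = new LogisticRegression(500, true, LogisticRegression::BATCH_TRAINING);
    $online    = new LogisticRegression(500, true, LogisticRegression::ONLINE_TRAINING);
    $conjugate = new LogisticRegression(500, true, LogisticRegression::CONJUGATE_GRAD_TRAINING);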
    /**
     * Executes the Conjugate Gradient method to optimize the weights of the LogReg model
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if ($this->optimizer === null) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

Issue: The method getCostValues() does not exist on Phpml\Helper\Optimizer\Optimizer. Since it exists in all sub-types, consider adding an abstract or default implementation to Phpml\Helper\Optimizer\Optimizer. If this is a false-positive, the issue can be suppressed at the call site via the ignore-call annotation:

        /** @scrutinizer ignore-call */
        $this->costValues = $this->optimizer->getCostValues();
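
Note: following the reviewer's suggestion, one way to resolve the issue is to give the optimizer base class a default implementation so the call above type-checks. This is a sketch only; the actual Phpml\Helper\Optimizer\Optimizer class may already hold its cost history differently:

    namespace Phpml\Helper\Optimizer;

    abstract class Optimizer
    {
        /** @var array */
        protected $costValues = [];

        public function getCostValues(): array
        {
            return $this->costValues;
        }
    }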
    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @throws \Exception
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty === 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of the log-likelihood cost function to be minimized:
                 *      J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If a regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *      ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                return function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value would be NaN, so we clamp these values
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };
            case 'sse':
                /*
                 * Sum of squared errors (least squares) cost function:
                 *      J(x) = ∑ (y - h(x))^2
                 *
                 * If a regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function:
                 *      ∇J(x) = -(y - h(x)) . h(x) . (1 - h(x))
                 */
                return function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    $y = $y < 0 ? 0 : 1;

                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };
            default:
                // Not reached
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }
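
Note: a worked example of the closure contract for the 'log' branch, using made-up numbers. For a sample where the model outputs h(x) = 0.7 and the target is y = 1:

    $error    = -1 * log(0.7) - (1 - 1) * log(1 - 0.7); // ≈ 0.3567
    $gradient = 0.7 - 1;                                // -0.3
    // the closure returns [0.3567, -0.3, $penalty]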
    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }
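
Note: a standalone sanity check of the sigmoid used here, not part of the class:

    $sigmoid = static function (float $sum): float {
        return 1.0 / (1.0 + exp(-$sum));
    };
    // $sigmoid(0.0)  == 0.5    (the decision boundary)
    // $sigmoid(2.0)  ≈ 0.8808
    // $sigmoid(-2.0) ≈ 0.1192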
    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }
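
Note: because output() is the sigmoid of the raw weighted sum, the 0.5 cutoff above is equivalent to thresholding the linear sum at 0, since 1/(1 + exp(-s)) > 0.5 exactly when s > 0.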
    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is the sigmoid of the sample's distance to the
     * decision plane, i.e. the value returned by output().
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
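
Note: a worked example of the label mapping above, with made-up values. If $this->labels is [-1, 1] (so -1 sits at index 0) and output() returns 0.8 for a sample, then:

    // predictProbability($sample, 1)  returns 0.8            (label index > 0)
    // predictProbability($sample, -1) returns 1 - 0.8 = 0.2  (label index 0)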
}