Passed
Push — master ( c44f3b...492344 )
by Arkadiusz
created 02:48

LogisticRegression::setLambda()   Rating: A

Complexity     Conditions: 1    Paths: 1
Size           Total Lines: 4   Code Lines: 2
Duplication    Lines: 0         Ratio: 0 %
Importance     Changes: 0

Metric   Value
dl       0
loc      4
rs       10
c        0
b        0
f        0
cc       1
eloc     2
nc       1
nop      1
<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Phpml\Classification\Classifier;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm
     */
    const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported
     *  - 'log' : log-likelihood
     *  - 'sse' : sum of squared errors
     *
     * @var string
     */
    protected $costFunction = 'sse';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of the regularization term. If λ is set to 0,
     * the regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initializes a Logistic Regression classifier with the maximum number
     * of iterations and the learning rule to be applied.
     *
     * The maximum number of iterations can be any integer greater than 0.
     * If normalizeInputs is set to true, every input given to the algorithm
     * is standardized by use of its standard deviation and mean.
     *
     * The cost function can be 'log' for log-likelihood or 'sse' for sum of
     * squared errors.
     *
     * The penalty (regularization term) can be 'L2', or an empty string to
     * cancel the penalty term.
     *
     * @param int    $maxIterations
     * @param bool   $normalizeInputs
     * @param int    $trainingType
     * @param string $cost
     * @param string $penalty
     *
     * @throws \Exception
     */
    public function __construct(int $maxIterations = 500, bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING, string $cost = 'sse',
        string $penalty = 'L2')
    {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (! in_array($trainingType, $trainingTypes)) {
            throw new \Exception("Logistic regression can only be trained with " .
                "batch (gradient descent), online (stochastic gradient descent) " .
                "or conjugate batch (conjugate gradients) algorithms");
        }

        if (! in_array($cost, ['log', 'sse'])) {
            throw new \Exception("Logistic regression cost function can be one of the following: \n" .
                "'log' for log-likelihood and 'sse' for sum of squared errors");
        }

        if ($penalty != '' && strtoupper($penalty) !== 'L2') {
            throw new \Exception("Logistic regression supports only 'L2' regularization");
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
Documentation Bug introduced by
The property $trainingType was declared of type string, but $trainingType is
of type integer. Maybe add a type cast? This check looks for assignments to
scalar types that may be of the wrong type; to ensure the code behaves as
expected, it may be a good idea to add an explicit cast, e.g.:

    $answer = 42;
    $correct = false;
    $correct = (bool) $answer;
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }
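For orientation, a minimal usage sketch of this constructor. The toy data is
hypothetical, and train()/predict() are assumed from php-ml's Classifier API:

    use Phpml\Classification\Linear\LogisticRegression;

    // Two tiny, linearly separable classes (illustrative data only).
    $samples = [[0.1, 0.2], [0.2, 0.1], [0.9, 0.8], [0.8, 0.9]];
    $targets = [0, 0, 1, 1];

    $classifier = new LogisticRegression(
        500,                                          // maxIterations
        true,                                         // normalizeInputs
        LogisticRegression::CONJUGATE_GRAD_TRAINING,  // trainingType
        'log',                                        // cost function
        'L2'                                          // penalty
    );
    $classifier->train($samples, $targets);
    var_dump($classifier->predict([0.85, 0.9]));      // expected: 1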
    /**
     * Sets the learning rate to be used when the gradient descent
     * algorithm is selected for training
     *
     * @param float $learningRate
     */
    public function setLearningRate(float $learningRate)
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Sets the lambda (λ) parameter of the regularization term. If 0 is
     * given, the regularization term is cancelled
     *
     * @param float $lambda
     */
    public function setLambda(float $lambda)
    {
        $this->lambda = $lambda;
    }
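Since λ defaults to 0.5, a small grid search over setLambda() is a natural
next step. A sketch only: the data variables and the scoreAccuracy() helper
below are hypothetical, not part of the library:

    $best = null;
    foreach ([0.0, 0.01, 0.1, 0.5, 1.0] as $lambda) {
        $classifier = new LogisticRegression();
        $classifier->setLambda($lambda);
        $classifier->train($trainSamples, $trainTargets);

        // scoreAccuracy() stands in for any held-out evaluation metric.
        $score = scoreAccuracy($classifier->predict($testSamples), $testTargets);
        if ($best === null || $score > $best['score']) {
            $best = ['lambda' => $lambda, 'score' => $score];
        }
    }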
    /**
     * Adapts the weights with respect to given samples and targets
     * by use of the selected solver
     */
    protected function runTraining()
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                return $this->runGradientDescent($callback, true);
Bug introduced by
$callback, as returned by $this->getCostFunction() above, can be null, but
runGradientDescent() does not accept null; consider adding a type check.
Unless the expression can never be null because of other conditions, it is
safer to guard it explicitly, e.g.:

    $x = mayReturnNull();
    if (! $x instanceof stdClass) {
        throw new \LogicException('$x must be defined.');
    }
    doesNotAcceptNull($x);
            case self::ONLINE_TRAINING:
                return $this->runGradientDescent($callback, false);
Bug introduced by
The same nullable-$callback issue applies to this runGradientDescent() call
(see the note above).
            case self::CONJUGATE_GRAD_TRAINING:
                return $this->runConjugateGradient($callback);
Bug introduced by
The same nullable-$callback issue applies here: runConjugateGradient() does
not accept null either.
        }
    }
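One way to address the three nullable-$callback findings above would be to
fail fast before dispatching to the optimizers. A sketch of the suggested
guard, not the library's actual fix:

    protected function runTraining()
    {
        $callback = $this->getCostFunction();

        // getCostFunction() falls through and returns null for an
        // unrecognized cost function name, so guard before the optimizers:
        if (! $callback instanceof \Closure) {
            throw new \LogicException("No cost function defined for '{$this->costFunction}'");
        }

        switch ($this->trainingType) {
            // ... dispatch to the selected solver exactly as above
        }
    }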
    /**
     * Executes the Conjugate Gradient method to optimize the
     * weights of the LogReg model
     */
    protected function runConjugateGradient(\Closure $gradientFunc)
    {
        $optimizer = (new ConjugateGradient($this->featureCount))
            ->setMaxIterations($this->maxIterations);

        $this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
        $this->costValues = $optimizer->getCostValues();
    }
    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @return \Closure
     */
    protected function getCostFunction()
    {
        $penalty = 0;
        if ($this->penalty == 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of the log-likelihood cost function to be minimized:
                 *      J(x) = ∑( -y . log(h(x)) - (1 - y) . log(1 - h(x)) )
                 *
                 * If a regularization term is given, it is added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *      ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value would be NaN, so we clamp these values
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }
                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };

                return $callback;

            case 'sse':
                /*
                 * Sum of squared errors (least squares) cost function:
                 *      J(x) = ∑ (y - h(x))^2
                 *
                 * If a regularization term is given, it is added to the cost:
                 *      for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function:
                 *      ∇J(x) = -(y - h(x)) . h(x) . (1 - h(x))
                 */
                $callback = function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };

                return $callback;
        }
    }
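A quick standalone illustration of why the clamping of $hX matters in the
'log' branch (plain PHP arithmetic, not library code):

    $y  = 1.0;
    $hX = 1.0;  // a fully saturated prediction

    // (1 - $y) * log(1 - $hX) evaluates to 0 * -INF, which is NAN in PHP:
    var_dump(-$y * log($hX) - (1 - $y) * log(1 - $hX)); // float(NAN)

    $hX = 1 - 1e-10;  // after the clamp
    var_dump(-$y * log($hX) - (1 - $y) * log(1 - $hX)); // ≈ 1.0e-10, finite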
    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     *
     * @param array $sample
     *
     * @return float
     */
    protected function output(array $sample)
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }
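For reference, the logistic (sigmoid) squashing applied here, shown as a
standalone sketch:

    $sigmoid = function (float $sum): float {
        return 1.0 / (1.0 + exp(-$sum));
    };

    echo $sigmoid(0.0);  // 0.5      : sample lies on the decision boundary
    echo $sigmoid(6.0);  // ≈ 0.9975 : strongly positive net input
    echo $sigmoid(-6.0); // ≈ 0.0025 : strongly negative net input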
    /**
     * Returns the class value (either -1 or 1) for the given input
     *
     * @param array $sample
     *
     * @return int
     */
    protected function outputClass(array $sample)
    {
        $output = $this->output($sample);

        if (round($output) > 0.5) {
            return 1;
        }

        return -1;
    }
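Note that round($output) is always 0.0 or 1.0 for an output in [0, 1], so the
comparison effectively reduces to $output >= 0.5 (PHP rounds halves up):

    var_dump(round(0.49) > 0.5); // bool(false) -> class -1
    var_dump(round(0.50) > 0.5); // bool(true)  -> class  1
    var_dump(round(0.51) > 0.5); // bool(true)  -> class  1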
    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is taken as the distance of the sigmoid output
     * from the decision boundary value of 0.5.
     *
     * @param array $sample
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label)
Duplication introduced by
This method seems to be duplicated in your project. Duplicated code is one of
the most pungent code smells; if the same code is needed in three or more
places, consider extracting it into a single class or operation. More detailed
suggestions are available in the "Code" section of the repository.
    {
        $predicted = $this->predictSampleBinary($sample);

        if (strval($predicted) == strval($label)) {
            $sample = $this->checkNormalizedSample($sample);

            return abs($this->output($sample) - 0.5);
        }

        return 0.0;
    }
}
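To act on the duplication finding above, the shared probability logic could
be pulled into a trait reused by the other linear classifiers. A sketch with
a hypothetical trait name; php-ml may resolve it differently:

    trait BinaryProbabilityEstimator // hypothetical name
    {
        /**
         * Shared predictProbability() implementation for the linear,
         * sigmoid-output classifiers.
         */
        protected function predictProbability(array $sample, $label)
        {
            $predicted = $this->predictSampleBinary($sample);

            if (strval($predicted) == strval($label)) {
                $sample = $this->checkNormalizedSample($sample);

                return abs($this->output($sample) - 0.5);
            }

            return 0.0;
        }
    }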