Passed
Push — master (fbbe5c...a34811)
created 07:00 by Arkadiusz

src/Phpml/Classification/Linear/Perceptron.php (1 issue)

Assigning incompatible types to properties.

Bug Documentation Major


<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Classification\Classifier;
use Phpml\Helper\OneVsRest;
use Phpml\Helper\Optimizer\GD;
use Phpml\Helper\Optimizer\StochasticGD;
use Phpml\Helper\Predictable;
use Phpml\IncrementalEstimator;
use Phpml\Preprocessing\Normalizer;

class Perceptron implements Classifier, IncrementalEstimator
{
    use Predictable, OneVsRest;

    /**
     * @var \Phpml\Helper\Optimizer\Optimizer
     */
    protected $optimizer;

    /**
     * @var array
     */
    protected $labels = [];

    /**
     * @var int
     */
    protected $featureCount = 0;

    /**
     * @var array
     */
    protected $weights = [];

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * @var Normalizer
     */
    protected $normalizer;

    /**
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * @var array
     */
    protected $costValues = [];
    /**
     * Initialize a perceptron classifier with the given learning rate and maximum
     * number of iterations used while training the perceptron
     *
     * @param float $learningRate  Value between 0.0 (exclusive) and 1.0 (inclusive)
     * @param int   $maxIterations Must be at least 1
     *
     * @throws \Exception
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000, bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new Exception('Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)');
        }

        if ($maxIterations <= 0) {
            throw new Exception('Maximum number of iterations must be an integer greater than 0');
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    public function partialTrain(array $samples, array $targets, array $labels = []): void
    {
        $this->trainByLabel($samples, $targets, $labels);
    }

    public function trainBinary(array $samples, array $targets, array $labels): void
    {
        if ($this->normalizer) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];
        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target == (string) $this->labels[1] ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->featureCount = count($samples[0]);

        $this->runTraining($samples, $targets);
    }

    /**
     * Enabling early stopping for the optimization procedure may save
     * processing time, while in some cases it may result in premature
     * convergence.
     *
     * If "false" is given, the optimization procedure will always be executed
     * for $maxIterations times
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Returns the cost values obtained during the training.
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    protected function resetBinary(): void
    {
        $this->labels = [];
        $this->optimizer = null;
        $this->featureCount = 0;
        $this->weights = null;
Documentation Bug introduced by
It seems like null of type null is incompatible with the declared type array of property $weights.

Our type inference engine has found an assignment to a property that is incompatible with the declared type of that property.

Either this assignment is in error or the assigned type should be added to the documentation/type hint for that property.

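A minimal sketch of one way to resolve this, assuming the intent is simply to clear the learned state between binary runs, is to reset the property to an empty array so it keeps matching its declared type:

    protected function resetBinary(): void
    {
        $this->labels = [];
        $this->optimizer = null;
        $this->featureCount = 0;
        $this->weights = [];      // empty array instead of null, consistent with @var array
        $this->costValues = [];
    }

Alternatively, the null assignment could be kept and the property docblock widened to @var array|null. The original listing continues below.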
        $this->costValues = [];
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent optimization
     * to get the correct set of weights
     */
    protected function runTraining(array $samples, array $targets)
    {
        // The cost function is the sum of squares
        $callback = function ($weights, $sample, $target) {
            $this->weights = $weights;

            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($samples, $targets, $callback);
    }

    /**
     * Executes a Gradient Descent algorithm for
     * the given cost function
     */
    protected function runGradientDescent(array $samples, array $targets, Closure $gradientFunc, bool $isBatch = false): void
    {
        $class = $isBatch ? GD::class : StochasticGD::class;

        if (empty($this->optimizer)) {
            $this->optimizer = (new $class($this->featureCount))
                ->setLearningRate($this->learningRate)
                ->setMaxIterations($this->maxIterations)
                ->setChangeThreshold(1e-6)
                ->setEarlyStop($this->enableEarlyStop);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample
     */
    protected function checkNormalizedSample(array $sample): array
    {
        if ($this->normalizer) {
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates net output of the network as a float value for the given input
     *
     * @return int|float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index == 0) {
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }
    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is simply taken as the distance of the sample
     * to the decision plane.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);

        if ((string) $predicted == (string) $label) {
            $sample = $this->checkNormalizedSample($sample);

            return (float) abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
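For context on how this class is consumed, a short illustrative usage sketch follows. It assumes the train() and predict() methods supplied by the OneVsRest and Predictable traits used above; the toy data and file paths are made up for illustration.

<?php

require_once 'vendor/autoload.php'; // assumes php-ml is installed via Composer

use Phpml\Classification\Linear\Perceptron;

// Illustrative, linearly separable toy data
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$targets = ['a', 'a', 'a', 'b', 'b', 'b'];

$classifier = new Perceptron(0.001, 5000);
$classifier->train($samples, $targets);    // train() comes from the OneVsRest trait
echo $classifier->predict([3, 2]);         // predict() comes from Predictable; expected to output 'b'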