Passed
Push — master ( fbbe5c...a34811 )
by Arkadiusz
07:00
created

src/Phpml/Classification/Linear/Perceptron.php (1 issue)

Upgrade to new PHP Analysis Engine

These results are based on our legacy PHP analysis; consider migrating to our new PHP analysis engine instead. Learn more

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Closure;
8
use Exception;
9
use Phpml\Classification\Classifier;
10
use Phpml\Helper\OneVsRest;
11
use Phpml\Helper\Optimizer\GD;
12
use Phpml\Helper\Optimizer\StochasticGD;
13
use Phpml\Helper\Predictable;
14
use Phpml\IncrementalEstimator;
15
use Phpml\Preprocessing\Normalizer;
16
17
class Perceptron implements Classifier, IncrementalEstimator
{
    use Predictable, OneVsRest;

    /**
     * Gradient-descent optimizer that fits $weights. Created lazily by
     * runGradientDescent() and cleared by resetBinary().
     *
     * @var \Phpml\Helper\Optimizer\Optimizer|null
     */
    protected $optimizer;

    /**
     * Maps the internal class values to the user-facing labels:
     * key 1 => first label, key -1 => second label (see trainBinary()).
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per sample, taken from the first training sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Model weights. Index 0 is the bias term; index i (i >= 1) multiplies
     * feature i - 1 of a sample (see output()). Empty until trained.
     *
     * @var array
     */
    protected $weights = [];

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Normalizer (Normalizer::NORM_STD) applied to samples before training
     * and prediction, or null when input normalization is disabled.
     *
     * @var Normalizer|null
     */
    protected $normalizer;

    /**
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * Cost values reported by the optimizer for the last training run.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initialize a perceptron classifier with the given learning rate and
     * maximum number of iterations used while training the perceptron.
     *
     * @param float $learningRate    Value between 0.0 (exclusive) and 1.0 (inclusive)
     * @param int   $maxIterations   Must be at least 1
     * @param bool  $normalizeInputs Whether samples are normalized with
     *                               Normalizer::NORM_STD before use
     *
     * @throws \Exception when $learningRate or $maxIterations is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000, bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new Exception('Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)');
        }

        if ($maxIterations <= 0) {
            throw new Exception('Maximum number of iterations must be an integer greater than 0');
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains (or continues training) the model on the given batch by
     * delegating to the OneVsRest trait's trainByLabel().
     *
     * @param array $samples Training samples
     * @param array $targets Target label for each sample
     * @param array $labels  Label set forwarded to trainByLabel()
     */
    public function partialTrain(array $samples, array $targets, array $labels = []): void
    {
        $this->trainByLabel($samples, $targets, $labels);
    }

    /**
     * Trains the perceptron on a two-class problem.
     *
     * @param array $samples Training samples, each an array of features
     * @param array $targets Target label for each sample
     * @param array $labels  The two labels: $labels[0] becomes internal class 1,
     *                       $labels[1] becomes internal class -1
     *
     * @throws \Exception when no samples are given
     */
    public function trainBinary(array $samples, array $targets, array $labels): void
    {
        // Without this guard, $samples[0] below fails with an undefined offset.
        if (count($samples) === 0) {
            throw new Exception('At least one sample is required for training');
        }

        if ($this->normalizer) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];

        // Labels may arrive as int or string, so both sides are cast to string.
        // The comparison is strict so that numeric strings such as '01' and '1'
        // are not treated as the same label.
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target === $positiveLabel ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->featureCount = count($samples[0]);

        $this->runTraining($samples, $targets);
    }

    /**
     * Normally enabling early stopping for the optimization procedure may
     * help saving processing time while in some cases it may result in
     * premature convergence.<br>
     *
     * If "false" is given, the optimization procedure will always be executed
     * for $maxIterations times
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Returns the cost values obtained during the training.
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Discards the trained state. The optimizer is recreated on the next
     * call to runGradientDescent().
     */
    protected function resetBinary(): void
    {
        $this->labels = [];
        $this->optimizer = null;
        $this->featureCount = 0;
        // Reset to an empty array rather than null: the property is declared as
        // array, and output() iterates over it.
        $this->weights = [];
        $this->costValues = [];
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent optimization
     * to get the correct set of weights.
     *
     * @param array $samples Training samples
     * @param array $targets Targets as prepared by trainBinary() (1 / -1)
     */
    protected function runTraining(array $samples, array $targets)
    {
        // The per-sample cost is the squared error of the thresholded output.
        $callback = function ($weights, $sample, $target) {
            // Install the optimizer's candidate weights so that outputClass()
            // evaluates the model being optimized.
            $this->weights = $weights;

            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($samples, $targets, $callback);
    }

    /**
     * Executes a Gradient Descent algorithm for the given cost function.
     *
     * The optimizer is created only on the first call; later calls (e.g. from
     * partialTrain()) reuse it, so $isBatch only takes effect on that first call.
     *
     * @param Closure $gradientFunc Receives (weights, sample, target) and
     *                              returns [error, gradient]
     * @param bool    $isBatch      Use batch GD instead of StochasticGD
     */
    protected function runGradientDescent(array $samples, array $targets, Closure $gradientFunc, bool $isBatch = false): void
    {
        if ($this->optimizer === null) {
            $class = $isBatch ? GD::class : StochasticGD::class;

            $this->optimizer = (new $class($this->featureCount))
                ->setLearningRate($this->learningRate)
                ->setMaxIterations($this->maxIterations)
                ->setChangeThreshold(1e-6)
                ->setEarlyStop($this->enableEarlyStop);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample.
     */
    protected function checkNormalizedSample(array $sample): array
    {
        if ($this->normalizer) {
            // Normalizer::transform() works on a list of samples by reference.
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates the net output of the network for the given input: the bias
     * (weight 0) plus the weighted sum of the features. Returns 0 when the
     * model has no weights.
     *
     * @return int|float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index === 0) {
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input.
     */
    protected function outputClass(array $sample): int
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns a confidence score for the sample belonging to the given label.
     *
     * The score is the absolute net output |w·x + b|. This is proportional to
     * the sample's distance from the decision plane (it is not divided by the
     * weight norm) and is not bounded to [0, 1]. Returns 0.0 when the predicted
     * label differs from $label.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);

        // Same strict, string-cast comparison as in trainBinary().
        if ((string) $predicted === (string) $label) {
            $sample = $this->checkNormalizedSample($sample);

            return (float) abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * Predicts the user-facing label for a single sample.
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
268