Passed
Push — master ( fbbe5c...a34811 ) by Arkadiusz, created 07:00

src/Phpml/Classification/Linear/Perceptron.php (1 issue)

<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Classification\Classifier;
use Phpml\Helper\OneVsRest;
use Phpml\Helper\Optimizer\GD;
use Phpml\Helper\Optimizer\StochasticGD;
use Phpml\Helper\Predictable;
use Phpml\IncrementalEstimator;
use Phpml\Preprocessing\Normalizer;

class Perceptron implements Classifier, IncrementalEstimator
{
    use Predictable, OneVsRest;

    /**
     * @var \Phpml\Helper\Optimizer\Optimizer|GD|StochasticGD|null
     */
    protected $optimizer;

    /**
     * @var array
     */
    protected $labels = [];

    /**
     * @var int
     */
    protected $featureCount = 0;

    /**
     * @var array|null
     */
    protected $weights = [];

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * @var Normalizer
     */
    protected $normalizer;

    /**
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * @var array
     */
    protected $costValues = [];

    /**
     * Initialize a perceptron classifier with given learning rate and maximum
     * number of iterations used while training the perceptron
     *
     * @param float $learningRate  Value between 0.0 (exclusive) and 1.0 (inclusive)
     * @param int   $maxIterations Must be at least 1
     *
     * @throws \Exception
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000, bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new Exception('Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)');
        }

        if ($maxIterations <= 0) {
            throw new Exception('Maximum number of iterations must be an integer greater than 0');
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    public function partialTrain(array $samples, array $targets, array $labels = []): void
    {
        $this->trainByLabel($samples, $targets, $labels);
    }

    public function trainBinary(array $samples, array $targets, array $labels): void
    {
        if ($this->normalizer) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];
        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target == (string) $this->labels[1] ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->featureCount = count($samples[0]);

        $this->runTraining($samples, $targets);
    }

    /**
     * Normally, enabling early stopping for the optimization procedure may
     * help save processing time, while in some cases it may result in
     * premature convergence.<br>
     *
     * If "false" is given, the optimization procedure will always be executed
     * for $maxIterations times
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Returns the cost values obtained during the training.
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    protected function resetBinary(): void
    {
        $this->labels = [];
        $this->optimizer = null;
        $this->featureCount = 0;
        $this->weights = null;
        $this->costValues = [];
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent optimization
     * to get the correct set of weights
     */
    protected function runTraining(array $samples, array $targets)
    {
        // The cost function is the sum of squares
        $callback = function ($weights, $sample, $target) {
            $this->weights = $weights;

            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($samples, $targets, $callback);
    }

    /**
     * Executes a Gradient Descent algorithm for
     * the given cost function
     */
    protected function runGradientDescent(array $samples, array $targets, Closure $gradientFunc, bool $isBatch = false): void
    {
        $class = $isBatch ? GD::class : StochasticGD::class;

        if ($this->optimizer === null) {
            $this->optimizer = (new $class($this->featureCount))
                ->setLearningRate($this->learningRate)
                ->setMaxIterations($this->maxIterations)
                ->setChangeThreshold(1e-6)
                ->setEarlyStop($this->enableEarlyStop);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample
     */
    protected function checkNormalizedSample(array $sample): array
    {
        if ($this->normalizer) {
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates net output of the network as a float value for the given input
     *
     * @return int|float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
Issue: The expression $this->weights of type array|null is not guaranteed to be traversable. How about adding an additional type check?

There are different options for fixing this problem.

  1. If you want to be on the safe side, you can add an additional type check:

    $collection = json_decode($data, true);
    if ( ! is_array($collection)) {
        throw new \RuntimeException('$collection must be an array.');
    }

    foreach ($collection as $item) { /** ... */ }

  2. If you are sure that the expression is traversable, you might want to add a doc comment cast to improve IDE auto-completion and static analysis:

    /** @var array $collection */
    $collection = json_decode($data, true);

    foreach ($collection as $item) { /** ... */ }

  3. Mark the issue as a false-positive.
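Applied to this class, the first option would amount to guarding the loop in output() before iterating. A minimal sketch only; the exception type and message are illustrative and not part of the library:

    // Illustrative guard at the top of output(), before the foreach loop
    if (!is_array($this->weights)) {
        throw new \RuntimeException('Perceptron has not been trained yet; weights are unavailable.');
    }

    foreach ($this->weights as $index => $w) { /* ... loop body unchanged ... */ }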
            if ($index == 0) {
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is simply taken as the distance of the sample
     * to the decision plane.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);

        if ((string) $predicted == (string) $label) {
            $sample = $this->checkNormalizedSample($sample);

            return (float) abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
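For context, the classifier is consumed through the train() and predict() methods contributed by the OneVsRest and Predictable traits. A minimal usage sketch, with illustrative sample data rather than anything taken from this build:

    use Phpml\Classification\Linear\Perceptron;

    // Two clearly separable groups of 2-dimensional samples
    $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
    $targets = ['a', 'a', 'a', 'b', 'b', 'b'];

    $classifier = new Perceptron(0.001, 1000);
    $classifier->train($samples, $targets);

    echo $classifier->predict([3, 2]); // expected to print 'b'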