Perceptron::predictProbability() (grade: A)

Complexity
  Conditions: 2
  Paths: 2

Size
  Total Lines: 11
  Code Lines: 5

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
eloc    5
dl      0
loc     11
rs      10
c       0
b       0
f       0
cc      2
nc      2
nop     2

Key (standard metric abbreviations): eloc = executable lines of code, loc = lines of code, dl = duplicated lines, cc = cyclomatic complexity, nc = NPath/path count, nop = number of parameters.
<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Phpml\Classification\Classifier;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\OneVsRest;
use Phpml\Helper\Optimizer\GD;
use Phpml\Helper\Optimizer\Optimizer;
use Phpml\Helper\Optimizer\StochasticGD;
use Phpml\Helper\Predictable;
use Phpml\IncrementalEstimator;
use Phpml\Preprocessing\Normalizer;

class Perceptron implements Classifier, IncrementalEstimator
{
    use Predictable;
    use OneVsRest;

    /**
     * @var Optimizer|GD|StochasticGD|null
     */
    protected $optimizer;

    /**
     * @var array
     */
    protected $labels = [];

    /**
     * @var int
     */
    protected $featureCount = 0;

    /**
     * @var array
     */
    protected $weights = [];

    /**
     * Cost values obtained during the training (see getCostValues()).
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * @var Normalizer
     */
    protected $normalizer;

    /**
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * Initialize a perceptron classifier with the given learning rate and
     * maximum number of iterations used while training the perceptron.
     *
     * @param float $learningRate  Value between 0.0 (exclusive) and 1.0 (inclusive)
     * @param int   $maxIterations Must be at least 1
     *
     * @throws InvalidArgumentException
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000, bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new InvalidArgumentException('Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)');
        }

        if ($maxIterations <= 0) {
            throw new InvalidArgumentException('Maximum number of iterations must be an integer greater than 0');
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    public function partialTrain(array $samples, array $targets, array $labels = []): void
    {
        $this->trainByLabel($samples, $targets, $labels);
    }

    public function trainBinary(array $samples, array $targets, array $labels): void
    {
        if ($this->normalizer !== null) {
            $this->normalizer->transform($samples);
        }

        // Map the two class labels onto the internal values +1 and -1:
        // the first label becomes +1, the second one -1. Note that
        // $this->labels[1] below is the entry stored under array key 1,
        // i.e. the first label, not $labels[1].
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];
        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target == (string) $this->labels[1] ? 1 : -1;
        }

        // Remember the feature count for the optimizer
        $this->featureCount = count($samples[0]);

        $this->runTraining($samples, $targets);
    }

    /**
     * Enabling early stopping for the optimization procedure usually saves
     * processing time, although in some cases it may result in premature
     * convergence.
     *
     * If "false" is given, the optimization procedure will always run for
     * $maxIterations iterations.
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

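    // Usage sketch (hypothetical values): because setEarlyStop() returns
    // $this, it can be chained directly onto construction:
    //     $classifier = (new Perceptron(0.001, 500))->setEarlyStop(false);
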
    /**
     * Returns the cost values obtained during the training.
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    protected function resetBinary(): void
    {
        $this->labels = [];
        $this->optimizer = null;
        $this->featureCount = 0;
        $this->weights = [];
        $this->costValues = [];
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent
     * optimization to find the proper set of weights.
     */
    protected function runTraining(array $samples, array $targets): void
    {
        // The cost function is the sum of squared errors; for each sample the
        // callback returns the squared error and its gradient.
        $callback = function ($weights, $sample, $target): array {
            // Use the candidate weights when computing the prediction
            $this->weights = $weights;

            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($samples, $targets, $callback);
    }

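    // Worked example for the callback above: with prediction = +1 and
    // target = -1, gradient = 1 - (-1) = 2 and error = 2 ** 2 = 4. A
    // correctly classified sample yields gradient 0 and error 0, so only
    // misclassified samples move the weights.
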
    /**
     * Executes a gradient descent algorithm for the given cost function.
     */
    protected function runGradientDescent(array $samples, array $targets, Closure $gradientFunc, bool $isBatch = false): void
    {
        $class = $isBatch ? GD::class : StochasticGD::class;

        if ($this->optimizer === null) {
            $this->optimizer = (new $class($this->featureCount))
                ->setLearningRate($this->learningRate)
                ->setMaxIterations($this->maxIterations)
                ->setChangeThreshold(1e-6)
                ->setEarlyStop($this->enableEarlyStop);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample.
     */
    protected function checkNormalizedSample(array $sample): array
    {
        if ($this->normalizer !== null) {
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates the net output of the network as a float value for the
     * given input. The weight at index 0 acts as the bias term.
     *
     * @return int|float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index == 0) {
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

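    // Worked example for output(): with weights [0.5, 2.0, -1.0] (bias 0.5,
    // hypothetical values) and sample [3.0, 4.0], the net output is
    // 0.5 + 2.0 * 3.0 + (-1.0) * 4.0 = 2.5.
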
    /**
     * Returns the class value (either -1 or 1) for the given input.
     */
    protected function outputClass(array $sample): int
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns the "probability" of the sample belonging to the given label.
     *
     * Note that this is not a calibrated probability in [0, 1]: the value is
     * simply the absolute net output, i.e. the unscaled distance of the
     * sample from the decision boundary, or 0.0 if the sample is predicted
     * to belong to the other class.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);

        if ((string) $predicted == (string) $label) {
            $sample = $this->checkNormalizedSample($sample);

            return (float) abs($this->output($sample));
        }

        return 0.0;
    }

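    // Example with hypothetical labels 'a' and 'b': if a sample with net
    // output 2.5 is predicted as 'a', predictProbability($sample, 'a')
    // returns 2.5, while predictProbability($sample, 'b') returns 0.0.
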
    /**
     * Predicts the label for a single sample and returns it.
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
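
For context, a minimal usage sketch of the class above follows. It assumes php-ml is installed via Composer; the toy dataset, the label names, and the expected outputs are illustrative only. train() and predict() are supplied by the OneVsRest and Predictable traits used in the listing.

<?php

require_once 'vendor/autoload.php';

use Phpml\Classification\Linear\Perceptron;

// Two linearly separable clusters, labelled 'small' and 'large'
// (hypothetical data, for illustration only).
$samples = [[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [5.0, 5.0], [4.8, 5.2], [5.1, 4.9]];
$targets = ['small', 'small', 'small', 'large', 'large', 'large'];

// Learning rate 0.001, at most 1000 iterations, input normalization enabled.
$classifier = new Perceptron(0.001, 1000, true);
$classifier->train($samples, $targets);     // provided by the OneVsRest trait

var_dump($classifier->predict([1.1, 0.9])); // expected: 'small'
var_dump($classifier->predict([5.2, 5.1])); // expected: 'large'

predictProbability() itself is protected; in php-ml it is exercised through the OneVsRest helper, which, for problems with more than two classes, trains one binary perceptron per label and essentially ranks the labels by the value predictProbability() returns, i.e. by each classifier's distance from its decision boundary.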