Passed
Pull Request — master (#287)
by Marcin
02:51
created

Perceptron::runTraining()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 15
rs 9.7666
c 0
b 0
f 0
cc 1
nc 1
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Closure;
8
use Phpml\Classification\Classifier;
9
use Phpml\Exception\InvalidArgumentException;
10
use Phpml\Helper\OneVsRest;
11
use Phpml\Helper\Optimizer\GD;
12
use Phpml\Helper\Optimizer\Optimizer;
13
use Phpml\Helper\Optimizer\StochasticGD;
14
use Phpml\Helper\Predictable;
15
use Phpml\IncrementalEstimator;
16
use Phpml\Preprocessing\Normalizer;
17
18
class Perceptron implements Classifier, IncrementalEstimator
19
{
20
    use Predictable, OneVsRest;

    /**
     * Optimizer that produces the weights; created on the first gradient
     * descent run and set back to null by resetBinary().
     *
     * @var Optimizer|GD|StochasticGD|null
     */
    protected $optimizer;

    /**
     * Maps the internal class values (1 and -1) to the original labels.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features, taken from the first training sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Current weights; the entry at index 0 is added as-is by output(),
     * every other entry is multiplied with a feature value.
     *
     * @var array
     */
    protected $weights = [];

    /**
     * Learning rate handed to the optimizer.
     *
     * @var float
     */
    protected $learningRate;

    /**
     * Maximum number of iterations handed to the optimizer.
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Input normalizer; stays null when the classifier is constructed with
     * $normalizeInputs = false.
     *
     * @var Normalizer|null
     */
    protected $normalizer;

    /**
     * Early-stop flag handed to the optimizer.
     *
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * Cost values reported by the optimizer after the latest training run.
     *
     * @var array
     */
    protected $costValues = [];
66
67
    /**
     * Creates a perceptron classifier with the given learning rate and the
     * maximum number of iterations used while training.
     *
     * @param float $learningRate    Value between 0.0 (exclusive) and 1.0 (inclusive)
     * @param int   $maxIterations   Must be at least 1
     * @param bool  $normalizeInputs When true, a Normalizer in NORM_STD mode is set up for the samples
     *
     * @throws InvalidArgumentException When the learning rate or the iteration count is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000, bool $normalizeInputs = true)
    {
        $rateTooLow = $learningRate <= 0.0;
        $rateTooHigh = $learningRate > 1.0;
        if ($rateTooLow || $rateTooHigh) {
            throw new InvalidArgumentException('Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)');
        }

        if ($maxIterations < 1) {
            throw new InvalidArgumentException('Maximum number of iterations must be an integer greater than 0');
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;

        // Without normalization the property simply keeps its null default
        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }
    }
93
94
    /**
     * Trains the classifier on one batch by handing all arguments over to
     * trainByLabel() (not defined in this class; presumably provided by one
     * of the used traits).
     *
     * @param array $samples Training samples of this batch
     * @param array $targets Target value of each sample
     * @param array $labels  Labels forwarded to trainByLabel(); empty by default
     */
    public function partialTrain(array $samples, array $targets, array $labels = []): void
    {
        $this->trainByLabel($samples, $targets, $labels);
    }
98
99
    /**
     * Trains the perceptron on a two-class problem.
     *
     * When a normalizer is configured the samples are passed through its
     * transform() first. The targets are then rewritten to the internal class
     * values: 1 for every target matching $labels[0], -1 for everything else.
     *
     * @param array $samples Training samples; the first one defines the feature count
     * @param array $targets Target value of each sample
     * @param array $labels  The two class labels; index 0 becomes class 1, index 1 becomes class -1
     */
    public function trainBinary(array $samples, array $targets, array $labels): void
    {
        if ($this->normalizer !== null) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];

        // Labels are matched by their string form. The comparison must be
        // strict: with "==" PHP compares two numeric strings as numbers, so
        // distinct labels such as "1" and "1.0" would collapse into one class.
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target === $positiveLabel ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->featureCount = count($samples[0]);

        $this->runTraining($samples, $targets);
    }
119
120
    /**
     * Switches early stopping of the optimization procedure on or off.
     *
     * Early stopping normally saves processing time, but in some cases it may
     * result in premature convergence. When it is disabled, the optimization
     * procedure is always executed for $maxIterations iterations.
     *
     * @param bool $enable Whether early stopping should be used
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }
136
137
    /**
     * Gives access to the cost values stored after the latest training run.
     *
     * @return array Cost values as reported by the optimizer
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }
144
145
    /**
     * Puts the training state back to its initial values: no optimizer, no
     * labels, no weights, no cost values and a feature count of zero.
     */
    protected function resetBinary(): void
    {
        $this->optimizer = null;
        $this->featureCount = 0;

        $this->labels = [];
        $this->weights = [];
        $this->costValues = [];
    }
153
154
    /**
     * Trains the perceptron model through runGradientDescent() in its default
     * (stochastic) mode to get the set of weights.
     *
     * The closure handed to the optimizer rates a single sample: it stores the
     * weights under evaluation on the instance, predicts the class of the
     * sample and returns the squared error together with the gradient.
     *
     * @param array $samples Training samples
     * @param array $targets Target of each sample (1 or -1 when coming from trainBinary())
     */
    protected function runTraining(array $samples, array $targets)
    {
        $sampleCost = function ($weights, $sample, $target) {
            // The prediction below has to use the weights being evaluated
            $this->weights = $weights;

            $gradient = $this->outputClass($sample) - $target;

            // Squared error of this sample first, gradient second
            return [$gradient ** 2, $gradient];
        };

        $this->runGradientDescent($samples, $targets, $sampleCost);
    }
173
174
    /**
     * Runs a gradient descent optimization for the given cost function and
     * stores the resulting weights and cost values on the instance.
     *
     * The optimizer is only built when none exists yet; an existing one is
     * reused as it is, whatever $isBatch says.
     *
     * @param array   $samples      Training samples
     * @param array   $targets      Target of each sample
     * @param Closure $gradientFunc Callback passed on to the optimizer
     * @param bool    $isBatch      Build a GD optimizer when true, a StochasticGD optimizer otherwise
     */
    protected function runGradientDescent(array $samples, array $targets, Closure $gradientFunc, bool $isBatch = false)
    {
        if ($this->optimizer === null) {
            $optimizerClass = StochasticGD::class;
            if ($isBatch) {
                $optimizerClass = GD::class;
            }

            $freshOptimizer = new $optimizerClass($this->featureCount);

            // Keep whatever the last setter of the chain returns
            $this->optimizer = $freshOptimizer
                ->setLearningRate($this->learningRate)
                ->setMaxIterations($this->maxIterations)
                ->setChangeThreshold(1e-6)
                ->setEarlyStop($this->enableEarlyStop);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }
193
194
    /**
     * Returns the sample after it went through the normalizer, or the
     * untouched sample when no normalizer is configured.
     *
     * @param array $sample Feature values of a single sample
     *
     * @return array The (possibly normalized) sample
     */
    protected function checkNormalizedSample(array $sample): array
    {
        if ($this->normalizer === null) {
            return $sample;
        }

        // transform() is fed a list of samples and changes it in place,
        // so the single sample is wrapped and unwrapped again
        $batch = [$sample];
        $this->normalizer->transform($batch);

        return $batch[0];
    }
208
209
    /**
     * Calculates the net output of the network for the given input: the
     * weight at index 0 plus the sum of every other weight multiplied with
     * the feature at the preceding position.
     *
     * @param array $sample Feature values; feature i is paired with weight i + 1
     *
     * @return int|float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            // Strict check: only the integer key 0 marks the bias weight. A
            // loose "== 0" would also match non-numeric string keys on PHP 7.
            $sum += $index === 0 ? $w : $w * $sample[$index - 1];
        }

        return $sum;
    }
227
228
    /**
     * Turns the net output for the given input into a class value: 1 when the
     * output is greater than zero, -1 otherwise.
     *
     * @param array $sample Feature values of a single sample
     */
    protected function outputClass(array $sample): int
    {
        if ($this->output($sample) > 0) {
            return 1;
        }

        return -1;
    }
235
236
    /**
     * Returns the probability of the sample of belonging to the given label.
     *
     * The probability is simply taken as the absolute net output for the
     * sample, i.e. its distance to the decision plane. It is 0.0 when the
     * predicted label differs from the given one.
     *
     * @param array $sample Feature values of a single, not yet normalized sample
     * @param mixed $label  Label the prediction is checked against
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);

        // Labels are matched by their string form. Strict comparison on
        // purpose: "==" would compare two numeric strings as numbers and
        // treat distinct labels such as "1" and "1.0" as equal.
        if ((string) $predicted !== (string) $label) {
            return 0.0;
        }

        // predictSampleBinary() normalized its own copy only
        $sample = $this->checkNormalizedSample($sample);

        return (float) abs($this->output($sample));
    }
256
257
    /**
     * Predicts the label of a single sample: the sample is normalized when a
     * normalizer is configured, classified as 1 or -1 and the label stored
     * for that class value is returned.
     *
     * @param array $sample Feature values of a single sample
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $normalized = $this->checkNormalizedSample($sample);

        return $this->labels[$this->outputClass($normalized)];
    }
268
}
269