Passed
Push — master ( c44f3b...492344 )
by Arkadiusz
02:48
created

Perceptron::runTraining()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 15
Code Lines 8

Duplication

Lines 15
Ratio 100 %

Importance

Changes 0
Metric Value
dl 15
loc 15
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 8
nc 1
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\OneVsRest;
9
use Phpml\Helper\Optimizer\StochasticGD;
10
use Phpml\Helper\Optimizer\GD;
11
use Phpml\Classification\Classifier;
12
use Phpml\Preprocessing\Normalizer;
13
14
/**
 * Binary perceptron classifier trained with gradient descent.
 *
 * trainBinary() handles two-class problems only; the OneVsRest trait
 * presumably builds multi-class support on top of it (defined elsewhere —
 * confirm there).
 */
class Perceptron implements Classifier
{
    use Predictable, OneVsRest;

    /**
     * Training samples, accumulated across calls to trainBinary().
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Training targets encoded as 1 or -1, accumulated across calls to
     * trainBinary().
     *
     * @var array
     */
    protected $targets = [];

    /**
     * After training, maps the internal classes 1 and -1 back to the
     * original target labels.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per sample, taken from the first stored sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Model weights: index 0 is the bias term, index i (i > 0) weights
     * feature i - 1 of a sample.
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Standardizer applied to samples; null when input normalization is off.
     *
     * @var Normalizer
     */
    protected $normalizer;

    /**
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * Cost values reported by the optimizer during the last training run.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initialize a perceptron classifier with the given learning rate and
     * maximum number of iterations used while training the perceptron.<br>
     *
     * Learning rate should be a float value between 0.0 (exclusive) and
     * 1.0 (inclusive).<br>
     * Maximum number of iterations must be an integer greater than 0.
     *
     * @param float $learningRate
     * @param int   $maxIterations
     * @param bool  $normalizeInputs when true, samples are transformed with a
     *                               Normalizer::NORM_STD normalizer before
     *                               training and prediction
     *
     * @throws \Exception if $learningRate or $maxIterations is out of range
     */
    public function __construct(
        float $learningRate = 0.001,
        int $maxIterations = 1000,
        bool $normalizeInputs = true
    ) {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains the classifier on a two-class problem.
     *
     * The first distinct label found in $targets is encoded as class 1 and
     * the second as class -1. The (optionally normalized) samples and the
     * encoded targets are appended to any previously stored ones, and the
     * optimizer is then run over the whole stored set.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception if $targets contains more than two distinct labels
     */
    public function trainBinary(array $samples, array $targets)
    {
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) > 2) {
            throw new \Exception("Perceptron is for binary (two-class) classification only");
        }

        if ($this->normalizer) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either 1 or -1.
        $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];

        // Compare as strings with === : PHP's == compares numeric strings
        // numerically, which would merge distinct labels such as "1" and "1.0".
        $positiveLabel = strval($this->labels[1]);
        foreach ($targets as $target) {
            $this->targets[] = strval($target) === $positiveLabel ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($this->samples[0]);

        $this->runTraining();
    }

    /**
     * Normally enabling early stopping for the optimization procedure may
     * help saving processing time while in some cases it may result in
     * premature convergence.<br>
     *
     * If "false" is given, the optimization procedure will always be executed
     * for $maxIterations times.
     *
     * @param bool $enable
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Returns the cost values obtained during the training.
     *
     * @return array
     */
    public function getCostValues()
    {
        return $this->costValues;
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent
     * optimization to get the correct set of weights.
     */
    protected function runTraining()
    {
        // Per-sample cost is the squared difference between the thresholded
        // prediction (1 or -1) and the target; the gradient is that difference.
        $callback = function ($weights, $sample, $target) {
            // Install the optimizer's current weights so outputClass() uses them.
            $this->weights = $weights;

            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($callback);
    }

    /**
     * Runs a gradient-descent optimizer over the stored samples and targets
     * using the given cost/gradient function, then stores the resulting
     * weights and the optimizer's cost values.
     *
     * @param \Closure $gradientFunc callback (weights, sample, target) returning
     *                               [error, gradient]
     * @param bool     $isBatch      true for batch GD, false for StochasticGD
     */
    protected function runGradientDescent(\Closure $gradientFunc, bool $isBatch = false)
    {
        $class = $isBatch ? GD::class : StochasticGD::class;

        $optimizer = (new $class($this->featureCount))
            ->setLearningRate($this->learningRate)
            ->setMaxIterations($this->maxIterations)
            ->setChangeThreshold(1e-6)
            ->setEarlyStop($this->enableEarlyStop);

        $this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
        $this->costValues = $optimizer->getCostValues();
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample.
     *
     * @param array $sample
     *
     * @return array
     */
    protected function checkNormalizedSample(array $sample)
    {
        if ($this->normalizer) {
            // The normalizer works on a list of samples, so wrap and unwrap.
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates the net output for the given input: the bias weight plus
     * the weighted sum of the sample's features.
     *
     * @param array $sample
     *
     * @return float|int
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index == 0) {
                // Index 0 is the bias term and has no matching feature.
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input.
     *
     * @param array $sample
     *
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns a confidence score for the sample belonging to the given label.
     *
     * When the predicted label equals $label, this is the absolute value of
     * the raw net output for the (normalized) sample; otherwise it is 0.0.
     * The value is not divided by the weight norm, so it is proportional to,
     * but not equal to, the distance from the decision plane, and it is not
     * a calibrated probability.
     *
     * @param array $sample
     * @param mixed $label
     *
     * @return float|int
     */
    protected function predictProbability(array $sample, $label)
    {
        $predicted = $this->predictSampleBinary($sample);

        // Strict string comparison, consistent with trainBinary().
        if (strval($predicted) === strval($label)) {
            $sample = $this->checkNormalizedSample($sample);
            return abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * Predicts the original label for a single sample.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
273