Passed
Pull Request — master (#63)
by
unknown
02:44
created

Perceptron::runTraining()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 15
Code Lines 8

Duplication

Lines 15
Ratio 100 %

Importance

Changes 0
Metric Value
dl 15
loc 15
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 8
nc 1
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\OneVsRest;
9
use Phpml\Helper\Optimizer\StochasticGD;
10
use Phpml\Helper\Optimizer\GD;
11
use Phpml\Classification\Classifier;
12
use Phpml\Preprocessing\Normalizer;
13
14
/**
 * Binary (two-class) linear classifier.
 *
 * Inputs are optionally standardised, target labels are mapped to +1 / -1 and
 * the weight vector is obtained by running a gradient descent optimizer over
 * a squared-error cost.
 *
 * NOTE(review): multi-class handling is presumably provided by the OneVsRest
 * trait, which is not visible in this file - confirm against the trait.
 */
class Perceptron implements Classifier
{
    use Predictable, OneVsRest;

    /**
     * Training samples accumulated over all trainBinary() calls.
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Training targets, already mapped to 1 / -1.
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Original labels keyed by their internal class value (1 and -1).
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per sample, taken from the first training sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Weight vector produced by the optimizer; index 0 is used as the bias term.
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Input normalizer; stays unset when input normalization is disabled.
     *
     * @var Normalizer|null
     */
    protected $normalizer;

    /**
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * Initialize a perceptron classifier with the given learning rate and
     * maximum number of training iterations.
     *
     * @param float $learningRate    Value in the range (0.0, 1.0]
     * @param int   $maxIterations   Value greater than 0
     * @param bool  $normalizeInputs Whether samples are passed through a
     *                               Normalizer::NORM_STD normalizer before use
     *
     * @throws \Exception If $learningRate or $maxIterations is out of range
     */
    public function __construct(
        float $learningRate = 0.001,
        int $maxIterations = 1000,
        bool $normalizeInputs = true
    ) {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains the classifier on a two-class data set.
     *
     * The first distinct target value becomes the positive class (1), the
     * second one the negative class (-1). Samples and targets are appended to
     * the ones stored by earlier calls before training is run.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception If $targets contains more than two distinct labels
     */
    public function trainBinary(array $samples, array $targets)
    {
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) > 2) {
            throw new \Exception("Perceptron is for binary (two-class) classification only");
        }

        if ($this->normalizer) {
            // The return value is not used: transform() is expected to
            // modify $samples in place.
            $this->normalizer->transform($samples);
        }

        // Map the original labels onto the internal class values 1 and -1.
        // With a single distinct label there is no negative class; fall back
        // to null instead of reading an undefined offset.
        $this->labels = [
            1 => $this->labels[0],
            -1 => $this->labels[1] ?? null,
        ];

        // Strict string comparison: a loose "==" compares numeric strings by
        // value, so distinct labels such as "1" and "01" would be merged.
        $positiveLabel = strval($this->labels[1]);
        foreach ($targets as $target) {
            $this->targets[] = strval($target) === $positiveLabel ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($this->samples[0]);

        $this->runTraining();
    }

    /**
     * Enables or disables early stopping of the optimization procedure.
     *
     * Early stopping may save processing time, but in some cases it may
     * result in premature convergence. If "false" is given, the optimization
     * procedure will always be executed for $maxIterations times.
     *
     * @param bool $enable
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent
     * optimization to get the correct set of weights.
     */
    protected function runTraining()
    {
        // The cost function is the sum of squares
        $callback = function ($weights, $sample, $target) {
            // outputClass() reads $this->weights, so the candidate weights
            // have to be installed before the prediction is made.
            $this->weights = $weights;

            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($callback);
    }

    /**
     * Executes a gradient descent optimizer for the given cost function and
     * stores the resulting weights.
     *
     * @param \Closure $gradientFunc Callback handed to the optimizer; the one
     *                               built in runTraining() takes
     *                               ($weights, $sample, $target) and returns
     *                               [$error, $gradient]
     * @param bool     $isBatch      Use GD when true, StochasticGD otherwise
     */
    protected function runGradientDescent(\Closure $gradientFunc, bool $isBatch = false)
    {
        $class = $isBatch ? GD::class : StochasticGD::class;

        $optimizer = (new $class($this->featureCount))
            ->setLearningRate($this->learningRate)
            ->setMaxIterations($this->maxIterations)
            ->setChangeThreshold(1e-6)
            ->setEarlyStop($this->enableEarlyStop);

        $this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample.
     *
     * @param array $sample
     *
     * @return array
     */
    protected function checkNormalizedSample(array $sample)
    {
        if ($this->normalizer) {
            // transform() is called with a list of samples, so wrap the
            // single sample and unwrap it again afterwards.
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates the net output for the given input: the bias weight
     * (index 0) plus the weighted sum of the sample's features.
     *
     * @param array $sample
     *
     * @return float|int
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index === 0) {
                // Weight 0 is the bias and has no matching feature.
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input.
     *
     * @param array $sample
     *
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns the probability of the sample of belonging to the given label.
     *
     * The probability is simply taken as the absolute net output for the
     * sample when the predicted label matches $label, and 0.0 otherwise.
     *
     * @param array $sample
     * @param mixed $label
     *
     * @return float|int
     */
    protected function predictProbability(array $sample, $label)
    {
        $predicted = $this->predictSampleBinary($sample);

        // Strict string comparison so that numeric-looking labels such as
        // "1" and "01" are not treated as the same class.
        if (strval($predicted) === strval($label)) {
            $sample = $this->checkNormalizedSample($sample);

            return abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * Predicts the original label for a single sample.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
257