Test Failed
Pull Request — master (#63)
by
unknown
02:47
created

Perceptron::runTraining()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 15
Code Lines 8

Duplication

Lines 15
Ratio 100 %

Importance

Changes 0
Metric Value
dl 15
loc 15
rs 9.4285
c 0
b 0
f 0
cc 1
eloc 8
nc 1
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\OneVsRest;
9
use Phpml\Helper\Optimizer\StochasticGD;
10
use Phpml\Helper\Optimizer\GD;
11
use Phpml\Classification\Classifier;
12
use Phpml\Preprocessing\Normalizer;
13
14
class Perceptron implements Classifier
{
    use Predictable, OneVsRest;

    /**
     * Training samples stored by trainBinary() (already normalized when a
     * normalizer is configured).
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Training targets stored by trainBinary(), each mapped to 1 or -1.
     *
     * @var array
     */
    protected $targets = [];

    /**
     * After binary training: maps the internal classes 1 and -1 back to the
     * original target labels.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per training sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Model weights as returned by the optimizer. Index 0 is used as the bias
     * term and index i (i > 0) weights feature i - 1 (see output()).
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Set only when inputs are normalized; null otherwise.
     *
     * @var Normalizer
     */
    protected $normalizer;

    /**
     * Initialize a perceptron classifier with the given learning rate and the
     * maximum number of iterations used while training the perceptron.
     *
     * The learning rate must be a float in (0.0, 1.0]. The maximum number of
     * iterations must be an integer greater than 0.
     *
     * @param float $learningRate    step size passed to the optimizer
     * @param int   $maxIterations   iteration cap passed to the optimizer
     * @param bool  $normalizeInputs when true, samples are standardized
     *                               (Normalizer::NORM_STD) before training
     *                               and prediction
     *
     * @throws \Exception when the learning rate or iteration count is out of range
     */
    public function __construct(
        float $learningRate = 0.001,
        int $maxIterations = 1000,
        bool $normalizeInputs = true
    ) {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains the perceptron on a two-class data set.
     *
     * The first distinct label (in order of appearance) is mapped to the
     * internal class 1 and the second to -1. Every call replaces any training
     * data stored by a previous call. Appending would pair old targets, which
     * were mapped with an older label order, with new samples.
     *
     * @param array $samples list of feature vectors
     * @param array $targets one label per sample
     *
     * @throws \Exception when the targets do not contain exactly two distinct
     *                    labels, or when sample and target counts differ
     */
    public function trainBinary(array $samples, array $targets)
    {
        $labels = array_keys(array_count_values($targets));
        if (count($labels) > 2) {
            throw new \Exception("Perceptron is for binary (two-class) classification only");
        }

        if (count($labels) < 2) {
            // With a single label, $labels[1] would be undefined and the
            // model would map one class to null.
            throw new \Exception("Perceptron requires exactly two distinct target labels");
        }

        if (count($samples) !== count($targets)) {
            throw new \Exception("Number of samples and targets must be equal");
        }

        if ($this->normalizer) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [1 => $labels[0], -1 => $labels[1]];
        $positiveLabel = (string) $this->labels[1];

        $this->targets = [];
        foreach ($targets as $target) {
            // Strict comparison: with `==`, numeric strings such as "1" and
            // "1.0" compare equal and distinct labels would be merged.
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        // Reindex so samples line up positionally with the target list.
        $this->samples = array_values($samples);
        $this->featureCount = count($this->samples[0]);

        $this->runTraining();
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent optimization
     * to get the correct set of weights.
     *
     * NOTE(review): static analysis (Scrutinizer) reports this method as
     * duplicated elsewhere in the project; consider extracting the shared logic.
     */
    protected function runTraining()
    {
        // Per-sample cost is the squared error between the predicted class
        // (-1 or 1) and the target; the gradient is their difference.
        $callback = function ($weights, $sample, $target) {
            // outputClass() reads $this->weights, so expose the optimizer's
            // current weights before predicting.
            $this->weights = $weights;

            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($callback);
    }

    /**
     * Runs a gradient-descent optimizer on the stored samples and targets
     * with the given cost function and stores the resulting weights.
     *
     * @param \Closure $gradientFunc callback ($weights, $sample, $target)
     *                               returning [$error, $gradient]
     * @param bool     $isBatch      true selects batch GD, false (default)
     *                               selects StochasticGD
     */
    protected function runGradientDescent(\Closure $gradientFunc, bool $isBatch = false)
    {
        $class = $isBatch ? GD::class : StochasticGD::class;

        $optimizer = (new $class($this->featureCount))
            ->setLearningRate($this->learningRate)
            ->setMaxIterations($this->maxIterations);

        $this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
    }

    /**
     * Checks if the sample should be normalized and, if so, returns the
     * normalized sample. The caller's array is not modified.
     *
     * @param array $sample
     *
     * @return array
     */
    protected function checkNormalizedSample(array $sample)
    {
        if ($this->normalizer) {
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates the net output for the given input: the bias (weight 0)
     * plus the weighted sum of the sample's features.
     *
     * @param array $sample
     *
     * @return float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index === 0) {
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input.
     *
     * @param array $sample
     *
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns a confidence score for the sample belonging to the given label.
     *
     * If the predicted label matches $label, the score is the absolute net
     * output, i.e. the unnormalized distance to the decision plane (it is not
     * divided by the weight norm). Otherwise the score is 0.0.
     *
     * @param array $sample
     * @param mixed $label
     *
     * @return float|int
     */
    protected function predictProbability(array $sample, $label)
    {
        $predicted = $this->predictSampleBinary($sample);

        if ((string) $predicted === (string) $label) {
            $sample = $this->checkNormalizedSample($sample);

            return abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * Predicts the original label for a single sample.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
233