Completed
Push — master ( cf222b...4daa0a )
by Arkadiusz
03:24
created

Perceptron::__construct()   B

Complexity

Conditions 5
Paths 4

Size

Total Lines 18
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 18
rs 8.8571
c 0
b 0
f 0
cc 5
eloc 10
nc 4
nop 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Classification\Classifier;
10
use Phpml\Preprocessing\Normalizer;
11
12
class Perceptron implements Classifier
{
    use Predictable;

    /**
     * Name of the method whose result is used to calculate the network error
     * for each instance during training. It is resolved through late static
     * binding (static::), so subclasses may override it.
     *
     * @var string
     */
    protected static $errorFunction = 'outputClass';

    /**
     * Training samples accumulated over all calls to train()
     * (already normalized when input normalization is enabled).
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Training targets accumulated over all calls to train(), encoded as 1 or -1.
     * Index i corresponds to $samples[i].
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Original class labels keyed by their encoded value: [1 => label, -1 => label].
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per sample (taken from the first stored sample).
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Bias at index 0, followed by one weight per feature.
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Input normalizer; null when normalization is disabled.
     *
     * @var Normalizer|null
     */
    protected $normalizer;

    /**
     * Initialize a perceptron classifier with given learning rate and maximum
     * number of iterations used while training the perceptron <br>
     *
     * Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive) <br>
     * Maximum number of iterations can be an integer value greater than 0
     *
     * @param float $learningRate    step size of each weight update, in (0.0, 1.0]
     * @param int   $maxIterations   maximum number of passes over the training set, > 0
     * @param bool  $normalizeInputs when true, samples are standardized (Normalizer::NORM_STD)
     *                               before training and prediction
     *
     * @throws \Exception when the learning rate or iteration count is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000,
        bool $normalizeInputs = true)
    {
        // Written as a negated range check so that NAN is rejected as well.
        if (!($learningRate > 0.0 && $learningRate <= 1.0)) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains the perceptron. The given samples and targets are appended to those
     * from earlier calls, and the weights are re-initialized and re-learned on the
     * whole accumulated set.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when the sample and target counts differ, or when the
     *                    accumulated targets do not contain exactly two classes
     */
    public function train(array $samples, array $targets)
    {
        if (count($samples) !== count($targets)) {
            throw new \Exception("Number of samples and targets must be equal");
        }

        $samples = array_values($samples);
        $targets = array_values($targets);

        // Previously seen labels come first so the ±1 encoding of already-stored
        // targets stays valid across incremental train() calls.
        $labels = array_keys(array_count_values(array_merge(array_values($this->labels), $targets)));
        if (count($labels) > 2) {
            throw new \Exception("Perceptron is for only binary (two-class) classification");
        }
        if (count($labels) < 2) {
            throw new \Exception("Perceptron requires samples from exactly two classes");
        }

        if ($this->normalizer) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either 1 (first label) or -1 (second label).
        // Loose comparison is intentional: array_count_values() turns numeric-string
        // labels into integer keys.
        $this->labels = [1 => $labels[0], -1 => $labels[1]];
        foreach ($targets as $target) {
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->samples = array_merge($this->samples, $samples);
        $this->featureCount = count($this->samples[0]);

        // Init bias + feature weights with random values in [0, 1]
        $this->weights = [];
        for ($i = 0; $i <= $this->featureCount; $i++) {
            $this->weights[] = rand() / (float) getrandmax();
        }

        $this->runTraining();
    }

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of perceptron learning rule.
     *
     * Stops early once a full pass produces no non-zero update: the weights did
     * not change during that pass, so every further pass would be identical.
     */
    protected function runTraining()
    {
        for ($iteration = 0; $iteration < $this->maxIterations; $iteration++) {
            $weightsChanged = false;

            foreach ($this->samples as $index => $sample) {
                $prediction = $this->{static::$errorFunction}($sample);
                $update = $this->targets[$index] - $prediction;
                if ($update != 0) {
                    $weightsChanged = true;
                }

                // Update bias
                $this->weights[0] += $update * $this->learningRate;
                // Update feature weights
                for ($i = 1; $i <= $this->featureCount; $i++) {
                    $this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
                }
            }

            if (!$weightsChanged) {
                break;
            }
        }
    }

    /**
     * Calculates the net output of the network for the given input:
     * bias + sum(weight_i * feature_i).
     *
     * @param array $sample
     *
     * @return float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $weight) {
            $sum += $index == 0 ? $weight : $weight * $sample[$index - 1];
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input.
     *
     * @param array $sample
     *
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Predicts the original class label for a single sample, normalizing it first
     * when normalization is enabled.
     *
     * @param array $sample
     *
     * @return mixed one of the labels seen during training
     */
    protected function predictSample(array $sample)
    {
        if ($this->normalizer) {
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
196