Completed
Push — master ( f0a798...cf222b )
by Arkadiusz
02:58
created

Perceptron::train()   B

Complexity

Conditions 5
Paths 7

Size

Total Lines 25
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 25
rs 8.439
c 0
b 0
f 0
cc 5
eloc 13
nc 7
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification\Linear;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Helper\Trainable;
9
use Phpml\Classification\Classifier;
10
11
class Perceptron implements Classifier
{
    use Predictable;

    /**
     * Name of the method whose result is compared with the encoded target
     * (-1 or 1) to compute the per-instance error during training.
     * Subclasses may override it (via static::) to train on another output.
     *
     * @var string
     */
    protected static $errorFunction = 'outputClass';

    /**
     * Training samples, accumulated across all train() calls.
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Encoded targets (1 or -1), index-aligned with $samples.
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Reverse mapping from encoded class (1 / -1) to the original label.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per sample, taken from the first stored sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Bias at index 0, followed by one weight per feature.
     *
     * @var array
     */
    protected $weights;

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Initialize a perceptron classifier with the given learning rate and maximum
     * number of iterations used while training the perceptron <br>
     *
     * Learning rate should be a float value between 0.0 (exclusive) and 1.0 (inclusive) <br>
     * Maximum number of iterations should be an integer value greater than 0
     *
     * @param float $learningRate
     * @param int $maxIterations
     *
     * @throws \Exception when either argument is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
        }

        if ($maxIterations <= 0) {
            throw new \Exception("Maximum number of iterations should be an integer greater than 0");
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Stores the given samples/targets (appending to any previously trained data),
     * re-initializes the weights randomly and runs the training loop.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception when sample and target counts differ, or when the targets
     *                    do not contain exactly two distinct classes
     */
    public function train(array $samples, array $targets)
    {
        if (count($samples) !== count($targets)) {
            throw new \Exception("Number of samples and targets should be equal");
        }

        // Validate into a local first so a rejected call leaves the model untouched.
        $labels = array_keys(array_count_values($targets));
        if (count($labels) > 2) {
            throw new \Exception("Perceptron is for only binary (two-class) classification");
        }
        if (count($labels) < 2) {
            throw new \Exception("Perceptron requires training targets from exactly two classes");
        }

        // Encode the first label as 1 and the second as -1; $this->labels keeps
        // the reverse mapping used by predictSample().
        $this->labels = [1 => $labels[0], -1 => $labels[1]];
        foreach ($targets as $target) {
            // Loose == is deliberate: array_count_values() converts numeric-string
            // labels to int keys, so a '1' target must still match label 1.
            $this->targets[] = $target == $this->labels[1] ? 1 : -1;
        }

        // Reindex the incoming samples so that, after merging, $this->samples[$i]
        // lines up with $this->targets[$i] even if $samples had string keys.
        $this->samples = array_merge($this->samples, array_values($samples));
        $this->featureCount = count($this->samples[0]);

        // Init weights (bias + one per feature) with random values in [0, 1].
        // A plain loop avoids leaving a dangling foreach-by-reference variable.
        $this->weights = [];
        for ($i = 0; $i <= $this->featureCount; $i++) {
            $this->weights[] = rand() / (float) getrandmax();
        }

        $this->runTraining();
    }

    /**
     * Adapts the weights with respect to the stored samples and targets
     * by use of the perceptron learning rule.
     *
     * Makes $maxIterations passes over all samples. For each sample the error
     * (target - prediction) scales the bias update and each feature's weight
     * update by the learning rate.
     */
    protected function runTraining()
    {
        for ($iteration = 0; $iteration < $this->maxIterations; $iteration++) {
            foreach ($this->samples as $index => $sample) {
                $target = $this->targets[$index];
                $prediction = $this->{static::$errorFunction}($sample);
                $update = $target - $prediction;

                // Bias term (no associated feature).
                $this->weights[0] += $update * $this->learningRate;

                // Weight $i pairs with feature $i - 1.
                for ($i = 1; $i <= $this->featureCount; $i++) {
                    $this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
                }
            }
        }
    }

    /**
     * Calculates the net output of the network (bias plus weighted feature sum)
     * for the given input.
     *
     * @param array $sample
     *
     * @return float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            // Index 0 is the bias; weight $index pairs with feature $index - 1.
            $sum += $index == 0 ? $w : $w * $sample[$index - 1];
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input:
     * 1 when the net output is strictly positive, -1 otherwise.
     *
     * @param array $sample
     *
     * @return int
     */
    protected function outputClass(array $sample)
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Predicts the original label of the given sample.
     *
     * @param array $sample
     *
     * @return mixed one of the two labels seen during training
     *
     * @throws \Exception when called before train()
     */
    protected function predictSample(array $sample)
    {
        if (empty($this->labels)) {
            throw new \Exception("Perceptron has not been trained yet");
        }

        return $this->labels[$this->outputClass($sample)];
    }
}
175