Passed
Pull Request — master (#315)
by Marcin
02:40
created

MultilayerPerceptron::initNetwork()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 14
rs 9.7998
c 0
b 0
f 0
cc 1
nc 1
nop 0
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\NeuralNetwork\Network;
6
7
use Phpml\Estimator;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\Helper\Predictable;
10
use Phpml\IncrementalEstimator;
11
use Phpml\NeuralNetwork\ActivationFunction;
12
use Phpml\NeuralNetwork\ActivationFunction\Sigmoid;
13
use Phpml\NeuralNetwork\Layer;
14
use Phpml\NeuralNetwork\Node\Bias;
15
use Phpml\NeuralNetwork\Node\Input;
16
use Phpml\NeuralNetwork\Node\Neuron;
17
use Phpml\NeuralNetwork\Node\Neuron\Synapse;
18
use Phpml\NeuralNetwork\Training\Backpropagation;
19
20
abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator, IncrementalEstimator
21
{
22
    use Predictable;
23
24
    /**
     * Class labels, re-indexed from zero by the constructor. The label at
     * index i names the i-th node of the output layer (see getOutput()).
     *
     * @var array
     */
    protected $classes = [];

    /**
     * Default activation function for the hidden layers. A null value is
     * passed through to the Layer constructor unchanged.
     *
     * @var ActivationFunction|null
     */
    protected $activationFunction;

    /**
     * Trainer instantiated by initNetwork() with the current learning rate.
     *
     * @var Backpropagation
     */
    protected $backpropagation;

    /**
     * Size handed to the input layer when the network is built.
     *
     * @var int
     */
    private $inputLayerFeatures;

    /**
     * Hidden layer specifications exactly as given to the constructor.
     *
     * @var array
     */
    private $hiddenLayers = [];

    /**
     * Learning rate forwarded to the Backpropagation trainer.
     *
     * @var float
     */
    private $learningRate;

    /**
     * Number of passes partialTrain() makes over the supplied samples.
     *
     * @var int
     */
    private $iterations;
58
59
    /**
     * Validates the requested topology, stores the configuration and builds
     * the initial network.
     *
     * @param int                     $inputLayerFeatures size handed to the input layer
     * @param array                   $hiddenLayers       one entry per hidden layer: a neuron count,
     *                                                    a [count, ActivationFunction] pair, or a Layer instance
     * @param array                   $classes            at least two unique class labels; keys are discarded
     * @param int                     $iterations         passes over the samples per partialTrain() call
     * @param ActivationFunction|null $activationFunction default activation function for the hidden layers
     * @param float                   $learningRate       learning rate forwarded to Backpropagation
     *
     * @throws InvalidArgumentException when no hidden layer is given, fewer than two
     *                                  classes are given, or the classes are not unique
     */
    public function __construct(int $inputLayerFeatures, array $hiddenLayers, array $classes, int $iterations = 10000, ?ActivationFunction $activationFunction = null, float $learningRate = 1.0)
    {
        if (count($hiddenLayers) === 0) {
            throw new InvalidArgumentException('Provide at least 1 hidden layer');
        }

        if (count($classes) < 2) {
            throw new InvalidArgumentException('Provide at least 2 different classes');
        }

        if (count($classes) !== count(array_unique($classes))) {
            throw new InvalidArgumentException('Classes must be unique');
        }

        // Re-index so that position i always corresponds to output node i.
        $this->classes = array_values($classes);
        $this->iterations = $iterations;
        $this->inputLayerFeatures = $inputLayerFeatures;
        $this->hiddenLayers = $hiddenLayers;
        $this->activationFunction = $activationFunction;
        // Static analysis flagged that an integer could reach this float-typed
        // property through the former `= 1` default; the default is now the
        // float literal 1.0, so the declared and documented types agree.
        $this->learningRate = $learningRate;

        $this->initNetwork();
    }
85
86
    /**
     * Trains the network from scratch: the existing layers are discarded,
     * the network is rebuilt from the stored configuration, and training is
     * then delegated to partialTrain() with the classes set at construction.
     */
    public function train(array $samples, array $targets): void
    {
        // Start from a freshly built network so earlier training is dropped.
        $this->reset();
        $this->initNetwork();

        $this->partialTrain($samples, $targets, $this->classes);
    }
92
93
    /**
     * Runs the configured number of training iterations over the given
     * samples without rebuilding the network first.
     *
     * @param array $classes optional; when non-empty it must list the same
     *                       labels, in the same order, as the constructor received
     *
     * @throws InvalidArgumentException when a non-empty $classes differs from
     *                                  the classes provided in the constructor
     */
    public function partialTrain(array $samples, array $targets, array $classes = []): void
    {
        // We require the list of classes in the constructor.
        $classesGiven = !empty($classes);
        if ($classesGiven && array_values($classes) !== $this->classes) {
            throw new InvalidArgumentException(
                'The provided classes don\'t match the classes provided in the constructor'
            );
        }

        $pass = 0;
        while ($pass < $this->iterations) {
            $this->trainSamples($samples, $targets);
            ++$pass;
        }
    }
109
110
    /**
     * Stores the new learning rate and forwards it to the current
     * Backpropagation trainer.
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;

        $trainer = $this->backpropagation;
        $trainer->setLearningRate($this->learningRate);
    }
115
116
    /**
     * Collects the current output of every node in the output layer.
     *
     * @return array node outputs keyed by the class label stored at the
     *               node's position
     */
    public function getOutput(): array
    {
        $outputNodes = $this->getOutputLayer()->getNodes();

        $outputByClass = [];
        foreach ($outputNodes as $position => $node) {
            $label = $this->classes[$position];
            $outputByClass[$label] = $node->getOutput();
        }

        return $outputByClass;
    }
125
126
    /**
     * Trains the network on a single sample. trainSamples() calls this once
     * per target during every training iteration.
     *
     * @param array $sample the sample stored under the same key as $target
     * @param mixed $target the expected result for $sample
     */
    abstract protected function trainSample(array $sample, $target);
130
131
    /**
     * Produces the prediction for a single sample.
     *
     * NOTE(review): presumably invoked by the Predictable trait; confirm
     * against the trait, which is not part of this file.
     *
     * @return mixed
     */
    abstract protected function predictSample(array $sample);
135
136
    /**
     * Drops every layer of the network; train() calls this before the
     * network is rebuilt.
     */
    protected function reset(): void
    {
        $this->removeLayers();
    }
140
141
    /**
     * Builds the network from the stored configuration: input layer, hidden
     * layers, output layer, bias nodes, synapses, and finally a new
     * Backpropagation trainer using the current learning rate.
     */
    private function initNetwork(): void
    {
        $this->addInputLayer($this->inputLayerFeatures);
        $this->addNeuronLayers($this->hiddenLayers, $this->activationFunction);

        // Sigmoid function for the output layer as we want a value from 0 to 1.
        // The output layer gets one neuron per class.
        $outputLayerSize = count($this->classes);
        $this->addNeuronLayers([$outputLayerSize], new Sigmoid());

        // Bias nodes go in before the synapses so they get connected as well.
        $this->addBiasNodes();
        $this->generateSynapses();

        $this->backpropagation = new Backpropagation($this->learningRate);
    }
155
156
    /**
     * Appends a layer of Input nodes to the network.
     *
     * @param int $nodes size passed to the Layer constructor
     */
    private function addInputLayer(int $nodes): void
    {
        $inputLayer = new Layer($nodes, Input::class);

        $this->addLayer($inputLayer);
    }
160
161
    /**
     * Appends one neuron layer per entry of $layers.
     *
     * Each entry may be:
     *  - an array whose element 0 is the neuron count and whose optional
     *    element 1 is the ActivationFunction for that layer,
     *  - a ready-made Layer instance, which is added as is,
     *  - any other value, which is used as the neuron count.
     *
     * @param array                   $layers                    layer specifications
     * @param ActivationFunction|null $defaultActivationFunction used when an entry does not
     *                                                           carry its own activation function
     */
    private function addNeuronLayers(array $layers, ?ActivationFunction $defaultActivationFunction = null): void
    {
        foreach ($layers as $layer) {
            if (is_array($layer)) {
                // Element 1 is optional: read it with ?? so that a spec such as
                // [4] falls back to the default instead of raising an
                // undefined offset notice/warning.
                $candidate = $layer[1] ?? null;
                $function = $candidate instanceof ActivationFunction ? $candidate : $defaultActivationFunction;
                $this->addLayer(new Layer($layer[0], Neuron::class, $function));
            } elseif ($layer instanceof Layer) {
                $this->addLayer($layer);
            } else {
                $this->addLayer(new Layer($layer, Neuron::class, $defaultActivationFunction));
            }
        }
    }
174
175
    /**
     * Connects every pair of adjacent layers: each layer except the last is
     * wired into the layer that follows it.
     */
    private function generateSynapses(): void
    {
        $lastIndex = count($this->layers) - 1;

        $index = 0;
        while ($index < $lastIndex) {
            $source = $this->layers[$index];
            $destination = $this->layers[$index + 1];
            $this->generateLayerSynapses($destination, $source);
            ++$index;
        }
    }
184
185
    /**
     * Adds one Bias node to every layer except the last one.
     */
    private function addBiasNodes(): void
    {
        // The final layer is skipped, hence the "- 1".
        $layersWithBias = count($this->layers) - 1;

        $index = 0;
        while ($index < $layersWithBias) {
            $this->layers[$index]->addNode(new Bias());
            ++$index;
        }
    }
192
193
    /**
     * Wires every Neuron of $nextLayer to all nodes of $currentLayer. Nodes
     * of $nextLayer that are not Neuron instances receive no synapses.
     */
    private function generateLayerSynapses(Layer $nextLayer, Layer $currentLayer): void
    {
        foreach ($nextLayer->getNodes() as $node) {
            if (!$node instanceof Neuron) {
                continue;
            }

            $this->generateNeuronSynapses($currentLayer, $node);
        }
    }
201
202
    /**
     * Gives $nextNeuron one incoming Synapse from each node of $currentLayer.
     */
    private function generateNeuronSynapses(Layer $currentLayer, Neuron $nextNeuron): void
    {
        $sourceNodes = $currentLayer->getNodes();

        foreach ($sourceNodes as $sourceNode) {
            $synapse = new Synapse($sourceNode);
            $nextNeuron->addSynapse($synapse);
        }
    }
208
209
    /**
     * Performs one pass over the training data: for every target, the sample
     * stored under the same key is handed to trainSample().
     */
    private function trainSamples(array $samples, array $targets): void
    {
        foreach ($targets as $index => $expected) {
            $sample = $samples[$index];
            $this->trainSample($sample, $expected);
        }
    }
215
}
216