Passed
Pull Request — master (#399)
by
unknown
05:27
created

MultilayerPerceptron::__construct()   A

Complexity

Conditions 4
Paths 4

Size

Total Lines 28
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value  |
|--------|--------|
| eloc   | 13     |
| dl     | 0      |
| loc    | 28     |
| rs     | 9.8333 |
| c      | 0      |
| b      | 0      |
| f      | 0      |
| cc     | 4      |
| nc     | 4      |
| nop    | 6      |
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\NeuralNetwork\Network;
6
7
use Phpml\Estimator;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\Helper\Predictable;
10
use Phpml\IncrementalEstimator;
11
use Phpml\NeuralNetwork\ActivationFunction;
12
use Phpml\NeuralNetwork\ActivationFunction\Sigmoid;
13
use Phpml\NeuralNetwork\Layer;
14
use Phpml\NeuralNetwork\Node\Bias;
15
use Phpml\NeuralNetwork\Node\Input;
16
use Phpml\NeuralNetwork\Node\Neuron;
17
use Phpml\NeuralNetwork\Node\Neuron\Synapse;
18
use Phpml\NeuralNetwork\Training\Backpropagation;
19
20
abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator, IncrementalEstimator
{
    use Predictable;

    /**
     * Class labels stored as a list; entry i is the label reported for output neuron i.
     *
     * @var array
     */
    protected $classes = [];

    /**
     * Default activation for hidden layers that do not specify one (null is passed through to Layer as-is).
     *
     * @var ActivationFunction|null
     */
    protected $activationFunction;

    /**
     * Trainer created by initNetwork() from the current learning rate.
     *
     * @var Backpropagation
     */
    protected $backpropagation;

    /**
     * Number of nodes in the input layer.
     *
     * @var int
     */
    private $inputLayerFeatures;

    /**
     * Hidden layer specifications as given to the constructor.
     *
     * @var array
     */
    private $hiddenLayers = [];

    /**
     * @var float
     */
    private $learningRate;

    /**
     * Number of passes over the training data per partialTrain() call.
     *
     * @var int
     */
    private $iterations;

    /**
     * Builds the network: input layer, hidden layers, a sigmoid output layer with one neuron per class,
     * bias nodes and fully connected synapses.
     *
     * @param int                     $inputLayerFeatures number of input nodes
     * @param array                   $hiddenLayers       hidden layer specs; each entry is a node count,
     *                                                    a [nodeCount, ActivationFunction] pair, or a Layer
     * @param array                   $classes            at least two distinct class labels
     * @param int                     $iterations         passes over the data per partialTrain() call
     * @param ActivationFunction|null $activationFunction default activation for hidden layers
     * @param float                   $learningRate       learning rate handed to Backpropagation
     *
     * @throws InvalidArgumentException when no hidden layer is given, or classes are fewer than two or not unique
     */
    public function __construct(
        int $inputLayerFeatures,
        array $hiddenLayers,
        array $classes,
        int $iterations = 10000,
        ?ActivationFunction $activationFunction = null,
        float $learningRate = 1.
    ) {
        if (count($hiddenLayers) === 0) {
            throw new InvalidArgumentException('Provide at least 1 hidden layer');
        }

        if (count($classes) < 2) {
            throw new InvalidArgumentException('Provide at least 2 different classes');
        }

        if (count($classes) !== count(array_unique($classes))) {
            throw new InvalidArgumentException('Classes must be unique');
        }

        $this->classes = array_values($classes);
        $this->iterations = $iterations;
        $this->inputLayerFeatures = $inputLayerFeatures;
        $this->hiddenLayers = $hiddenLayers;
        $this->activationFunction = $activationFunction;
        $this->learningRate = $learningRate;

        $this->initNetwork();
    }

    /**
     * Rebuilds the network from scratch and trains it on the given data.
     *
     * @throws InvalidArgumentException when a target has no corresponding sample
     */
    public function train(array $samples, array $targets): void
    {
        $this->reset();
        $this->initNetwork();
        $this->partialTrain($samples, $targets, $this->classes);
    }

    /**
     * Continues training the current network (no reset) for $this->iterations passes over the data.
     *
     * @param array $samples sample arrays, keyed like $targets
     * @param array $targets target labels; every key must have an array sample under the same key in $samples
     * @param array $classes optional; when non-empty it must equal the constructor's classes (as a list)
     *
     * @throws InvalidArgumentException on a class mismatch or when a target has no corresponding sample
     */
    public function partialTrain(array $samples, array $targets, array $classes = []): void
    {
        if (count($classes) > 0 && array_values($classes) !== $this->classes) {
            // We require the list of classes in the constructor.
            throw new InvalidArgumentException(
                'The provided classes don\'t match the classes provided in the constructor'
            );
        }

        // Validate before training so a missing sample cannot abort a pass half-way with a TypeError.
        foreach (array_keys($targets) as $key) {
            if (!is_array($samples[$key] ?? null)) {
                throw new InvalidArgumentException('Each target must have a corresponding sample');
            }
        }

        for ($i = 0; $i < $this->iterations; ++$i) {
            $this->trainSamples($samples, $targets);
        }
    }

    /**
     * Updates the learning rate here and on the current Backpropagation instance.
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
        $this->backpropagation->setLearningRate($this->learningRate);
    }

    /**
     * Returns the current output of each output neuron, keyed by its class label.
     */
    public function getOutput(): array
    {
        $result = [];
        foreach ($this->getOutputLayer()->getNodes() as $i => $neuron) {
            $result[$this->classes[$i]] = $neuron->getOutput();
        }

        return $result;
    }

    public function getLearningRate(): float
    {
        return $this->learningRate;
    }

    public function getBackpropagation(): Backpropagation
    {
        return $this->backpropagation;
    }

    /**
     * Writes getTrainedCharacteristics() to $filename as JSON.
     *
     * @throws InvalidArgumentException when the data cannot be JSON-encoded or the file cannot be written
     */
    public function saveTrainingIntoJsonFile(string $filename): void
    {
        $json = json_encode($this->getTrainedCharacteristics());
        if ($json === false) {
            // Without this check file_put_contents() would silently write an empty file.
            throw new InvalidArgumentException('Training could not be encoded as JSON');
        }

        if (file_put_contents($filename, $json) === false) {
            throw new InvalidArgumentException('Training file could not be saved');
        }
    }

    /**
     * Reads JSON written by saveTrainingIntoJsonFile() and passes it to setTrainedCharacteristics().
     *
     * @throws InvalidArgumentException when the file is missing, unreadable, or does not decode to an array
     */
    public function loadTrainingFromJsonFile(string $filename): void
    {
        if (!file_exists($filename)) {
            throw new InvalidArgumentException('File does not exist, it cannot be loaded');
        }

        $contents = file_get_contents($filename);
        if ($contents === false) {
            throw new InvalidArgumentException('Training file could not be loaded');
        }

        $characteristics = json_decode($contents, true);
        if (!is_array($characteristics)) {
            // Invalid JSON decodes to null; reject it before handing it to setTrainedCharacteristics().
            throw new InvalidArgumentException('Training file does not contain valid training data');
        }

        $this->setTrainedCharacteristics($characteristics);
    }

    /**
     * Trains the network on a single sample.
     *
     * @param mixed $target
     */
    abstract protected function trainSample(array $sample, $target): void;

    /**
     * @return mixed
     */
    abstract protected function predictSample(array $sample);

    /**
     * Drops all layers so initNetwork() can rebuild them.
     */
    protected function reset(): void
    {
        $this->removeLayers();
    }

    /**
     * Adds all layers, bias nodes and synapses, then creates a fresh Backpropagation trainer.
     */
    private function initNetwork(): void
    {
        $this->addInputLayer($this->inputLayerFeatures);
        $this->addNeuronLayers($this->hiddenLayers, $this->activationFunction);

        // Sigmoid function for the output layer as we want a value from 0 to 1.
        $sigmoid = new Sigmoid();
        $this->addNeuronLayers([count($this->classes)], $sigmoid);

        // Bias nodes must exist before synapses are generated so they get connected too.
        $this->addBiasNodes();
        $this->generateSynapses();

        $this->backpropagation = new Backpropagation($this->learningRate);
    }

    private function addInputLayer(int $nodes): void
    {
        $this->addLayer(new Layer($nodes, Input::class));
    }

    /**
     * Adds one layer per spec: a [nodeCount, ActivationFunction] array (activation optional),
     * a ready-made Layer, or a plain node count.
     */
    private function addNeuronLayers(array $layers, ?ActivationFunction $defaultActivationFunction = null): void
    {
        foreach ($layers as $layer) {
            if (is_array($layer)) {
                // The activation entry is optional; fall back to the default when absent or not an ActivationFunction.
                $function = isset($layer[1]) && $layer[1] instanceof ActivationFunction
                    ? $layer[1]
                    : $defaultActivationFunction;
                $this->addLayer(new Layer($layer[0], Neuron::class, $function));
            } elseif ($layer instanceof Layer) {
                $this->addLayer($layer);
            } else {
                $this->addLayer(new Layer($layer, Neuron::class, $defaultActivationFunction));
            }
        }
    }

    /**
     * Fully connects every pair of adjacent layers.
     */
    private function generateSynapses(): void
    {
        $layersNumber = count($this->layers) - 1;
        for ($i = 0; $i < $layersNumber; ++$i) {
            $currentLayer = $this->layers[$i];
            $nextLayer = $this->layers[$i + 1];
            $this->generateLayerSynapses($nextLayer, $currentLayer);
        }
    }

    /**
     * Appends a Bias node to every layer except the last (output) layer.
     */
    private function addBiasNodes(): void
    {
        $biasLayers = count($this->layers) - 1;
        for ($i = 0; $i < $biasLayers; ++$i) {
            $this->layers[$i]->addNode(new Bias());
        }
    }

    /**
     * Connects every node of $currentLayer to each Neuron of $nextLayer; nodes that are not Neurons are skipped.
     */
    private function generateLayerSynapses(Layer $nextLayer, Layer $currentLayer): void
    {
        foreach ($nextLayer->getNodes() as $nextNeuron) {
            if ($nextNeuron instanceof Neuron) {
                $this->generateNeuronSynapses($currentLayer, $nextNeuron);
            }
        }
    }

    private function generateNeuronSynapses(Layer $currentLayer, Neuron $nextNeuron): void
    {
        foreach ($currentLayer->getNodes() as $currentNeuron) {
            $nextNeuron->addSynapse(new Synapse($currentNeuron));
        }
    }

    /**
     * One pass over the data: trains on each target with the sample stored under the same key.
     */
    private function trainSamples(array $samples, array $targets): void
    {
        foreach ($targets as $key => $target) {
            $this->trainSample($samples[$key], $target);
        }
    }
}
252