Completed
Push — develop ( 2412f1...f0bd5a )
by Arkadiusz
04:16
created

MLPRegressor::train()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 9
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
c 1
b 0
f 0
dl 0
loc 9
rs 9.6666
cc 1
eloc 5
nc 1
nop 2
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\Regression;
6
7
8
use Phpml\Helper\Predictable;
9
use Phpml\NeuralNetwork\ActivationFunction;
10
use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
11
use Phpml\NeuralNetwork\Training\Backpropagation;
12
13
/**
 * Regressor backed by a multilayer perceptron.
 *
 * Every call to train() builds a fresh MultilayerPerceptron. Its input layer is
 * sized from the first sample, followed by the configured hidden layers, and its
 * output layer is sized from the first target. The network is then fitted with
 * Backpropagation. Predictions feed a single sample through the trained network.
 */
class MLPRegressor implements Regression
{
    use Predictable;

    /**
     * Network built by the most recent train() call; null until train() has run.
     *
     * @var MultilayerPerceptron|null
     */
    private $perceptron;

    /**
     * Hidden layer sizes, ordered from the input side towards the output side
     * (presumably positive node counts — not validated here).
     *
     * @var array
     */
    private $hiddenLayers;

    /**
     * Error threshold forwarded to Backpropagation::train().
     *
     * @var float
     */
    private $desiredError;

    /**
     * Iteration cap forwarded to Backpropagation::train().
     *
     * @var int
     */
    private $maxIterations;

    /**
     * Activation function forwarded to MultilayerPerceptron; may be null.
     *
     * @var ActivationFunction|null
     */
    private $activationFunction;

    /**
     * @param array                   $hiddenLayers       one entry per hidden layer, giving that layer's size
     * @param float                   $desiredError       error threshold passed to the backpropagation trainer
     * @param int                     $maxIterations      iteration cap passed to the backpropagation trainer
     * @param ActivationFunction|null $activationFunction passed unchanged to MultilayerPerceptron
     *                                                    (null presumably selects its default — confirm)
     */
    public function __construct(array $hiddenLayers, float $desiredError, int $maxIterations, ActivationFunction $activationFunction = null)
    {
        // $hiddenLayers previously declared a `[100]` default even though required
        // parameters follow it. Such a default can never be used (PHP 8 deprecates
        // it and treats the parameter as required), so it was removed.
        $this->hiddenLayers = $hiddenLayers;
        $this->desiredError = $desiredError;
        $this->maxIterations = $maxIterations;
        $this->activationFunction = $activationFunction;
    }

    /**
     * Builds a new network sized for the given data and fits it with backpropagation.
     *
     * @param array $samples list of feature vectors; $samples[0] determines the input layer size
     * @param array $targets list of targets; $targets[0] determines the output layer size
     *                       (a scalar target counts as a single output)
     *
     * @throws \InvalidArgumentException when $samples or $targets has no element at index 0
     */
    public function train(array $samples, array $targets)
    {
        if (!isset($samples[0], $targets[0])) {
            throw new \InvalidArgumentException('Samples and targets must be non-empty lists');
        }

        // array_merge re-indexes numeric keys so every layer is kept in order.
        // The union operator (`+`) used before kept only the first value per key
        // and silently dropped hidden/output layers.
        $layers = array_merge(
            [count($samples[0])],
            array_values($this->hiddenLayers),
            [is_array($targets[0]) ? count($targets[0]) : 1]
        );

        $this->perceptron = new MultilayerPerceptron($layers, $this->activationFunction);

        $trainer = new Backpropagation($this->perceptron);
        $trainer->train($samples, $targets, $this->desiredError, $this->maxIterations);
    }

    /**
     * Feeds one sample through the trained network.
     *
     * @param array $sample feature vector, same width as the training samples
     *
     * @return array the value returned by the perceptron's getOutput()
     *
     * @throws \LogicException when called before train()
     */
    protected function predictSample(array $sample)
    {
        if ($this->perceptron === null) {
            throw new \LogicException('MLPRegressor must be trained before predicting');
        }

        return $this->perceptron->setInput($sample)->getOutput();
    }
}
82