Passed
Push — master ( 7ab80b...4af844 )
by Arkadiusz
03:30
created

MLPClassifier::trainSample()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 9
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 9
rs 9.6666
c 0
b 0
f 0
cc 1
eloc 3
nc 1
nop 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Classification;
6
7
use Phpml\Classification\Classifier;
8
use Phpml\Exception\InvalidArgumentException;
9
use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
10
use Phpml\NeuralNetwork\Training\Backpropagation;
11
use Phpml\NeuralNetwork\ActivationFunction;
12
use Phpml\NeuralNetwork\Layer;
13
use Phpml\NeuralNetwork\Node\Bias;
14
use Phpml\NeuralNetwork\Node\Input;
15
use Phpml\NeuralNetwork\Node\Neuron;
16
use Phpml\NeuralNetwork\Node\Neuron\Synapse;
17
use Phpml\Helper\Predictable;
18
19
/**
 * Multilayer perceptron used as a classifier.
 *
 * Training labels are mapped to their key in `$this->classes` (a property
 * provided by the parent class, not visible here), and a prediction is the
 * label whose output-layer value is the largest.
 */
class MLPClassifier extends MultilayerPerceptron implements Classifier
{
    /**
     * Resolve a class label to its key in the configured class list.
     *
     * @param mixed $target class label as it appears in the training targets
     *
     * @return int key of $target within $this->classes
     *
     * @throws InvalidArgumentException when $target is not one of the known classes
     */
    public function getTargetClass($target): int
    {
        // A single lookup replaces the former in_array() + array_search() pair.
        // NOTE(review): loose (==) matching is kept so behaviour matches the
        // previous implementation; consider passing strict=true once callers
        // are known to supply labels of the same type as $this->classes.
        $index = array_search($target, $this->classes);
        if ($index === false) {
            throw InvalidArgumentException::invalidTarget($target);
        }

        return $index;
    }

    /**
     * Predict the class label of a single sample.
     *
     * Runs a forward pass, picks the output with the largest value (the first
     * one wins on ties) and maps its key back to the class label.
     *
     * @param array $sample feature values fed to the network input
     *
     * @return mixed the predicted class label
     */
    protected function predictSample(array $sample)
    {
        $output = $this->setInput($sample)->getOutput();

        // Seed the argmax from the first output rather than from 0: with a 0
        // seed, a network whose outputs are all <= 0 never updated
        // $predictedClass and ended up indexing $this->classes with null.
        $predictedClass = null;
        $max = null;
        foreach ($output as $class => $value) {
            if ($predictedClass === null || $value > $max) {
                $predictedClass = $class;
                $max = $value;
            }
        }

        return $this->classes[$predictedClass];
    }

    /**
     * Perform one training step on a single sample.
     *
     * @param array $sample feature values fed to the network input
     * @param mixed $target class label of the sample; must be one of $this->classes
     *
     * @throws InvalidArgumentException when $target is not a known class
     */
    protected function trainSample(array $sample, $target)
    {
        // Feed-forward. The returned output is discarded; presumably the call
        // is made so the layers hold this sample's activations before
        // backpropagate() reads them — confirm against MultilayerPerceptron.
        $this->setInput($sample)->getOutput();

        // Back-propagate the error for the target's class index.
        $this->backpropagation->backpropagate($this->getLayers(), $this->getTargetClass($target));
    }
}
68