Passed
Branch master (6cbf2c)
by Iván
05:35 queued 02:43
created

GradientDescendent::doStep()   A

Complexity

Conditions 4
Paths 8

Size

Total Lines 21
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 15
CRAP Score 4

Importance

Changes 1
Bugs 1 Features 0
Metric Value
c 1
b 1
f 0
dl 0
loc 21
ccs 15
cts 15
cp 1
rs 9.0534
cc 4
eloc 12
nc 8
nop 5
crap 4
1
<?php
2
3
namespace MachineLearning\Application\Algorithm;
4
5
use MachineLearning\Domain\Exception\DivergenceException;
6
use MachineLearning\Domain\Model\Dataset;
7
use MachineLearning\Domain\Hypothesis\HypothesisInterface;
8
use MachineLearning\Domain\Model\Result;
9
use MachineLearning\Domain\Model\ValueInterface;
10
use MachineLearning\Domain\Model\Value\VectorValue;
11
12
class GradientDescendent implements AlgorithmInterface
{
    const DEFAULT_LEARNING_RATE = 0.01;

    const DEFAULT_CONVERGENCE_CRITERIA = 0.0000001;

    const DEFAULT_DIVERGENCE_CRITERIA = 100;

    /**
     * Hypothesis whose coefficients are being fitted.
     *
     * @var HypothesisInterface
     */
    protected $hypothesis;

    /**
     * Step size applied to every coefficient update.
     *
     * @var double
     */
    protected $learningRate;

    /**
     * Threshold below which the aggregate update counts as converged.
     *
     * @var double
     */
    protected $convergenceCriteria;

    /**
     * Threshold above which the aggregate update counts as diverged.
     *
     * @var double
     */
    protected $divergenceCriteria;

    /**
     * @param HypothesisInterface $hypothesis
     * @param double $learningRate
     * @param double $convergenceCriteria
     * @param double $divergenceCriteria
     */
    public function __construct(
        HypothesisInterface $hypothesis,
        $learningRate = self::DEFAULT_LEARNING_RATE,
        $convergenceCriteria = self::DEFAULT_CONVERGENCE_CRITERIA,
        $divergenceCriteria = self::DEFAULT_DIVERGENCE_CRITERIA
    ) {
        $this->hypothesis = $hypothesis;
        $this->learningRate = (double) $learningRate;
        $this->convergenceCriteria = (double) $convergenceCriteria;
        $this->divergenceCriteria = (double) $divergenceCriteria;
    }

    /**
     * Runs batch gradient descent over the dataset, iterating until the
     * aggregate coefficient update either converges or diverges.
     *
     * @param Dataset $dataset
     * @return ValueInterface Fitted coefficient vector (intercept at index 0).
     * @throws DivergenceException When the update magnitude exceeds the divergence criteria.
     */
    public function train(Dataset $dataset) {
        $hasConverged = false;
        $hasDiverged = false;

        // Dimensionality is read from the first sample; every sample is
        // assumed to share it — NOTE(review): not validated here.
        $firstSample = $dataset->current();
        $featureCount = count($firstSample->getIndependentVariable()->getValue());
        $sampleCount = count($dataset);

        // Index 0 is the intercept and starts at 0; all other coefficients start at 1.
        $coefficients = array_fill(0, $featureCount + 1, 1);
        $coefficients[0] = 0;
        $increments = array_fill(0, $featureCount + 1, 0);

        do {
            list($hasConverged, $hasDiverged, $coefficients) = $this->doStep(
                $dataset,
                $coefficients,
                $featureCount,
                $sampleCount,
                $increments
            );
        } while (!$hasConverged && !$hasDiverged);

        if ($hasDiverged) {
            throw new DivergenceException();
        }

        return new VectorValue($coefficients);
    }

    /**
     * Performs one gradient-descent iteration: accumulates the cost across
     * the whole dataset, derives and applies the per-coefficient increments,
     * and classifies the aggregate update as converged or diverged.
     *
     * @param Dataset $dataset
     * @param array $coefficientVector Current coefficients (intercept first).
     * @param int $features Number of independent variables per sample.
     * @param int $total Number of samples in the dataset.
     * @param array $incrementVector Scratch vector; fully overwritten here.
     * @return array Triplet: [bool $convergence, bool $divergence, array $coefficientVector].
     */
    protected function doStep(Dataset $dataset, $coefficientVector, $features, $total, $incrementVector)
    {
        $currentCoefficients = new VectorValue($coefficientVector);
        $accumulatedCost = array_fill(0, $features + 1, 0);

        foreach ($dataset as $sample) {
            $accumulatedCost = $this->calculateStepCost($features, $currentCoefficients, $sample, $accumulatedCost);
        }

        // Derive each increment from the averaged cost and apply it in the
        // same pass; the convergence test below only inspects the finished
        // increment vector, never the pre-update coefficients.
        for ($index = 0; $index < $features + 1; $index++) {
            $incrementVector[$index] = $this->learningRate * -(1 / $total) * $accumulatedCost[$index];
            $coefficientVector[$index] += $incrementVector[$index];
        }

        // Signed increments are summed, so opposite-sign components cancel
        // in this aggregate measure.
        $aggregateUpdate = abs(array_sum($incrementVector));
        $convergence = $aggregateUpdate < $this->convergenceCriteria;
        $divergence = $aggregateUpdate > $this->divergenceCriteria;

        return array($convergence, $divergence, $coefficientVector);
    }

    /**
     * Adds a single sample's contribution to the running cost vector.
     *
     * Index 0 accumulates the raw residual (intercept term); each index
     * j >= 1 accumulates the residual weighted by the hypothesis derivative
     * with respect to feature j - 1.
     *
     * @param int $features
     * @param ValueInterface $coefficient
     * @param Result $result
     * @param array $costVector
     * @return array Updated cost vector.
     */
    protected function calculateStepCost($features, ValueInterface $coefficient, Result $result, array $costVector)
    {
        $predicted = (float)$this->hypothesis->calculate(
            $coefficient,
            $result->getIndependentVariable()
        )->getValue();
        $residual = $predicted - (float)$result->getDependentVariable()->getValue();

        $costVector[0] += $residual;

        for ($term = 1; $term < $features + 1; $term++) {
            $slope = (float)$this->hypothesis->derivative(
                $coefficient,
                $result->getIndependentVariable(),
                $term - 1
            )->getValue();
            $costVector[$term] += $residual * $slope;
        }

        return $costVector;
    }
}
151