Completed
Push — develop ( 633974...f7b91b )
by Arkadiusz
02:36
created

LeastSquares::train()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 7
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 3
Bugs 0 Features 1
Metric Value
c 3
b 0
f 1
dl 0
loc 7
rs 9.4285
cc 1
eloc 4
nc 1
nop 2
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\Regression;
6
7
use Phpml\Math\Matrix;
8
9
/**
 * Ordinary least squares linear regression.
 *
 * Fits target = intercept + sum_i(coefficient_i * feature_i) by solving the
 * normal equations b = (X'X)^-1 X'Y, where X is the sample matrix augmented
 * with a leading column of ones so that the first solved coefficient is the
 * intercept.
 */
class LeastSquares implements Regression
{
    /**
     * Training samples as passed to train(); each sample is a list of feature values.
     *
     * @var array
     */
    private $samples;

    /**
     * Training targets as passed to train(): either a flat list of values or a list of rows.
     *
     * @var array
     */
    private $targets;

    /**
     * Fitted intercept (the coefficient of the prepended column of ones).
     *
     * @var float
     */
    private $intercept;

    /**
     * Fitted per-feature coefficients, indexed 0..n-1 in feature order.
     *
     * @var array
     */
    private $coefficients;

    /**
     * Store the training data and fit the model.
     *
     * @param array $samples list of samples, each a list of feature values
     * @param array $targets flat list of target values, or a list of single-value rows
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = $samples;
        $this->targets = $targets;

        $this->computeCoefficients();
    }

    /**
     * Predict the target for one sample as intercept + sum(coefficient[i] * sample[i]).
     *
     * @param array $sample feature values indexed the same way as the training samples (0..n-1)
     *
     * @return mixed
     */
    public function predict($sample)
    {
        $result = $this->intercept;
        foreach ($this->coefficients as $index => $coefficient) {
            $result += $coefficient * $sample[$index];
        }

        return $result;
    }

    /**
     * @return array fitted per-feature coefficients (intercept excluded)
     */
    public function getCoefficients()
    {
        return $this->coefficients;
    }

    /**
     * @return float fitted intercept
     */
    public function getIntercept()
    {
        return $this->intercept;
    }

    /**
     * Solve the normal equations: coefficient(b) = (X'X)^-1 X'Y.
     */
    private function computeCoefficients()
    {
        $samplesMatrix = $this->getSamplesMatrix();
        $targetsMatrix = $this->getTargetsMatrix();

        // X' is used on both sides of the equation; compute it only once.
        $samplesTranspose = $samplesMatrix->transpose();

        $ts = $samplesTranspose->multiply($samplesMatrix)->inverse();
        $tf = $samplesTranspose->multiply($targetsMatrix);

        $this->coefficients = $ts->multiply($tf)->getColumnValues(0);
        // The first solved value belongs to the prepended column of ones, i.e. the intercept;
        // array_shift also re-indexes the remaining coefficients from 0.
        $this->intercept = array_shift($this->coefficients);
    }

    /**
     * Build X: prepend a constant 1 to every sample so the intercept is solved as a coefficient.
     *
     * @return Matrix
     */
    private function getSamplesMatrix(): Matrix
    {
        $samples = [];
        foreach ($this->samples as $sample) {
            array_unshift($sample, 1);
            $samples[] = $sample;
        }

        return new Matrix($samples);
    }

    /**
     * Build Y from the targets, accepting either a list of rows or a flat list of values.
     *
     * @return Matrix
     */
    private function getTargetsMatrix(): Matrix
    {
        // Look at the first element without assuming key 0 exists; indexing [0] raised an
        // undefined-offset notice for empty or non-zero-indexed target arrays.
        $firstTarget = reset($this->targets);
        if (is_array($firstTarget)) {
            return new Matrix($this->targets);
        }

        return Matrix::fromFlatArray($this->targets);
    }
}
117