Completed
Push — develop ( 95bfc8...365a9b )
by Arkadiusz
02:42
created

LeastSquares::predictSample()   A

Complexity

Conditions 2
Paths 2

Size

Total Lines 9
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 1
Metric Value
dl 0
loc 9
rs 9.6666
c 1
b 0
f 1
cc 2
eloc 5
nc 2
nop 1
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\Regression;
6
7
use Phpml\Helper\Predictable;
8
use Phpml\Math\Matrix;
9
10
class LeastSquares implements Regression
{
    use Predictable;

    /**
     * Samples passed to the most recent call to train().
     *
     * @var array
     */
    private $samples;

    /**
     * Targets passed to the most recent call to train(): either a flat list of
     * scalar values or a list of rows (arrays).
     *
     * @var array
     */
    private $targets;

    /**
     * Constant term of the fitted model (first value of the least-squares solution).
     *
     * @var float
     */
    private $intercept;

    /**
     * Per-feature weights of the fitted model, indexed 0..n-1 in feature order.
     *
     * @var array
     */
    private $coefficients;

    /**
     * Fit an ordinary least-squares model to the given data.
     *
     * Each call replaces any previously stored training data; it does not
     * accumulate samples across calls.
     *
     * @param array $samples list of feature vectors
     * @param array $targets one target per sample (flat list of values, or list of rows)
     *
     * @throws \InvalidArgumentException when $samples or $targets is empty
     */
    public function train(array $samples, array $targets)
    {
        // Fail fast with a clear message instead of undefined-offset notices
        // and an obscure failure inside the matrix arithmetic.
        if (count($samples) === 0 || count($targets) === 0) {
            throw new \InvalidArgumentException('Training requires at least one sample and one target');
        }

        $this->samples = $samples;
        $this->targets = $targets;

        $this->computeCoefficients();
    }

    /**
     * Evaluate the fitted linear model for one sample:
     * intercept + sum(coefficient[i] * sample[i]).
     *
     * @param array $sample feature vector indexed like the training samples (0..n-1)
     *
     * @return mixed the predicted value
     */
    public function predictSample(array $sample)
    {
        $result = $this->intercept;
        foreach ($this->coefficients as $index => $coefficient) {
            $result += $coefficient * $sample[$index];
        }

        return $result;
    }

    /**
     * @return array per-feature weights of the fitted model
     */
    public function getCoefficients()
    {
        return $this->coefficients;
    }

    /**
     * @return float constant term of the fitted model
     */
    public function getIntercept()
    {
        return $this->intercept;
    }

    /**
     * Solve the normal equations: coefficient(b) = (X'X)^-1 X'Y.
     *
     * X is the sample matrix with a leading column of ones, so the first value
     * of the solution is the intercept and the remaining values are the
     * per-feature coefficients.
     */
    private function computeCoefficients()
    {
        $samplesMatrix = $this->getSamplesMatrix();
        $targetsMatrix = $this->getTargetsMatrix();

        // X' is used by both products; compute it once.
        $transposed = $samplesMatrix->transpose();

        $ts = $transposed->multiply($samplesMatrix)->inverse();
        $tf = $transposed->multiply($targetsMatrix);

        // Only column 0 of the solution is kept, so with row-valued targets
        // only the first target column is modelled.
        $this->coefficients = $ts->multiply($tf)->getColumnValues(0);
        $this->intercept = array_shift($this->coefficients);
    }

    /**
     * Build X: every sample prefixed with a constant 1 so the solution
     * includes an intercept term.
     *
     * @return Matrix
     */
    private function getSamplesMatrix()
    {
        $samples = [];
        foreach ($this->samples as $sample) {
            // $sample is a by-value copy, so the stored samples are not modified.
            array_unshift($sample, 1);
            $samples[] = $sample;
        }

        return new Matrix($samples);
    }

    /**
     * Build Y from the stored targets.
     *
     * Row-valued targets (each target is an array) are used directly;
     * scalar targets become a single-column matrix.
     *
     * @return Matrix
     */
    private function getTargetsMatrix()
    {
        // Re-index so rows line up with getSamplesMatrix() (which also builds a
        // 0-based list) and so the first-element probe does not depend on key 0.
        $targets = array_values($this->targets);

        if (is_array($targets[0])) {
            return new Matrix($targets);
        }

        return Matrix::fromFlatArray($targets);
    }
}
119