Completed
Push — develop ( 3e4dc3...633974 )
by Arkadiusz
02:40
created

LeastSquares::computeCoefficients()   A

Complexity

Conditions 1
Paths 1

Size

Total Lines 11
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
c 1
b 0
f 0
dl 0
loc 11
rs 9.4285
cc 1
eloc 7
nc 1
nop 0
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\Regression;
6
7
use Phpml\Math\Matrix;
8
9
class LeastSquares implements Regression
{
    /**
     * Training samples as passed to train(): one list of feature values per observation.
     *
     * @var array
     */
    private $samples;

    /**
     * Training targets as passed to train(): either a flat list of values
     * ([y1, y2, ...]) or a list of single-element rows ([[y1], [y2], ...]).
     *
     * @var array
     */
    private $targets;

    /**
     * Fitted intercept (bias) term.
     *
     * @var float
     */
    private $intercept;

    /**
     * Fitted per-feature coefficients, indexed from 0 in the same order as
     * the feature values in each sample.
     *
     * @var array
     */
    private $coefficients;

    /**
     * Fit an ordinary least squares model to the given data.
     *
     * @param array $samples list of samples, each a list of feature values (no bias column needed)
     * @param array $targets one target per sample, either flat values or single-element rows
     */
    public function train(array $samples, array $targets)
    {
        $this->samples = $samples;
        $this->targets = $targets;

        $this->computeCoefficients();
    }

    /**
     * Predict the target for one sample as intercept + sum(coefficient_i * sample_i).
     *
     * @param array $sample feature values, in the same order as the training samples
     *
     * @return mixed
     */
    public function predict($sample)
    {
        $result = $this->intercept;
        foreach ($this->coefficients as $index => $coefficient) {
            $result += $coefficient * $sample[$index];
        }

        return $result;
    }

    /**
     * @return array fitted per-feature coefficients (the intercept is not included)
     */
    public function getCoefficients()
    {
        return $this->coefficients;
    }

    /**
     * Solve the normal equation b = (X'X)^-1 X'Y.
     *
     * X is the sample matrix with a leading column of ones. The first entry
     * of b is therefore the intercept, and the remaining entries are the
     * per-feature coefficients aligned with the sample's own indices.
     */
    private function computeCoefficients()
    {
        $samplesMatrix = $this->getSamplesMatrix();
        $targetsMatrix = $this->getTargetsMatrix();

        // X' is needed twice; compute it once.
        $transposed = $samplesMatrix->transpose();

        $ts = $transposed->multiply($samplesMatrix)->inverse();
        $tf = $transposed->multiply($targetsMatrix);

        $this->coefficients = $ts->multiply($tf)->getColumnValues(0);
        // First solved value belongs to the constant column of ones, i.e. the intercept.
        $this->intercept = array_shift($this->coefficients);
    }

    /**
     * Build the design matrix X. Each sample is prefixed with a constant 1 so
     * that the regression learns an intercept term.
     *
     * @return Matrix
     */
    private function getSamplesMatrix()
    {
        $samples = [];
        foreach ($this->samples as $sample) {
            // $sample is a copy, so the stored training data is not modified.
            array_unshift($sample, 1);
            $samples[] = $sample;
        }

        return new Matrix($samples);
    }

    /**
     * Build the target column matrix Y. Accepts flat targets ([y1, y2, ...])
     * as well as row-shaped targets ([[y1], [y2], ...]).
     *
     * @return Matrix
     */
    private function getTargetsMatrix()
    {
        $targets = [];
        foreach ($this->targets as $target) {
            $targets[] = is_array($target) ? $target : [$target];
        }

        return new Matrix($targets);
    }
}
81