Completed
Push — master ( ce1e03...f67568 )
by Simon
14:21
created

GPR.fit()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 3
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import numpy as np

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern
from sklearn.ensemble import ExtraTreesRegressor as _ExtraTreesRegressor_
# NOTE(review): despite its name, this alias binds ExtraTreesRegressor.
from sklearn.ensemble import ExtraTreesRegressor as _RandomForestRegressor_
from sklearn.ensemble import RandomForestRegressor as _SklearnRandomForestRegressor_
13
14
def _return_std(X, trees, predictions, min_variance):
    """Return the per-sample predictive standard deviation of a tree ensemble.

    The trees are treated as an equally weighted mixture. For every sample,
    each tree contributes the impurity of the leaf the sample falls into
    (used as that tree's variance, floored at ``min_variance``) and its own
    prediction (used as that tree's mean). The variance of the mixture is
    ``mean_over_trees(var + mean**2) - predictions**2``. Negative results are
    set to zero before the square root is taken.

    Parameters
    ----------
    X : array-like
        Samples, passed unchanged to ``tree.apply`` and ``tree.predict``.
    trees : sequence
        Fitted trees exposing ``tree_.impurity``, ``apply`` and ``predict``.
        The impurity is assumed to be a variance, i.e. an MSE criterion;
        the caller is responsible for checking that.
    predictions : ndarray
        Ensemble mean for each sample.
    min_variance : float
        Lower bound applied to every leaf impurity before it is used.

    Returns
    -------
    ndarray of shape (len(X),)
        Standard deviation for each sample.
    """
    # Running sum over trees of (leaf variance + leaf mean ** 2).
    second_moment = np.zeros(len(X))

    for estimator in trees:
        # Fancy indexing yields a copy, so the fitted tree is never modified.
        leaf_variance = estimator.tree_.impurity[estimator.apply(X)]
        leaf_variance = np.where(
            leaf_variance < min_variance, min_variance, leaf_variance
        )
        second_moment += leaf_variance + estimator.predict(X) ** 2

    # Mixture variance: average second moment minus the squared ensemble mean.
    second_moment /= len(trees)
    second_moment -= predictions ** 2.0

    # Clamp negative values (e.g. floating-point round-off) before the sqrt.
    np.putmask(second_moment, second_moment < 0.0, 0.0)
    return second_moment ** 0.5
29
30
31
class TreeEnsembleBase:
    """Mixin that lets a scikit-learn forest regressor return a predictive std.

    List it *before* the scikit-learn regressor in a subclass's bases.
    ``__init__``, ``fit`` and ``predict`` delegate to the next class in the
    MRO through ``super()``.
    """

    # Split criteria whose node impurity is the mean squared error, i.e. the
    # within-leaf variance that ``_return_std`` relies on. scikit-learn
    # deprecated "mse" in favour of "squared_error" in 1.0 (now the default),
    # so both spellings must be accepted.
    _VARIANCE_CRITERIA = ("mse", "squared_error")

    def __init__(self, min_variance=0.0, **kwargs):
        """Store ``min_variance`` and forward every other keyword argument.

        Parameters
        ----------
        min_variance : float, default=0.0
            Lower bound applied to each leaf variance when computing the std.
        **kwargs
            Passed unchanged to the next ``__init__`` in the MRO.
        """
        self.min_variance = min_variance
        super().__init__(**kwargs)

    def fit(self, X, y):
        """Fit the ensemble on ``X`` and a flattened ``y``.

        ``y`` is passed through ``np.ravel`` so that column vectors are
        accepted. Returns ``self``, following the scikit-learn convention.
        """
        super().fit(X, np.ravel(y))
        return self

    def predict(self, X, return_std=False):
        """Predict the ensemble mean, optionally with its standard deviation.

        Parameters
        ----------
        X : array-like
            Samples to predict.
        return_std : bool, default=False
            If true, also return the per-sample standard deviation.

        Returns
        -------
        ndarray of shape (n_samples, 1), or tuple
            The mean reshaped to a column. When ``return_std`` is true, the
            tuple ``(mean_column, std)`` with a 1-D ``std``.

        Raises
        ------
        ValueError
            If ``return_std`` is true and ``self.criterion`` is not an MSE
            criterion, so leaf impurities are not variances.
        """
        mean = super().predict(X)
        if not return_std:
            return mean.reshape(-1, 1)

        if self.criterion not in self._VARIANCE_CRITERIA:
            raise ValueError(
                "Expected impurity to be 'mse' or 'squared_error', got %s instead"
                % self.criterion
            )
        std = _return_std(X, self.estimators_, mean, self.min_variance)
        return mean.reshape(-1, 1), std
50
51
52
class RandomForestRegressor(TreeEnsembleBase, _SklearnRandomForestRegressor_):
    """scikit-learn ``RandomForestRegressor`` whose ``predict`` can return a std.

    The base is the real scikit-learn random forest. The module's
    ``_RandomForestRegressor_`` alias actually binds ``ExtraTreesRegressor``.
    """

    def __init__(self, min_variance=0.0, **kwargs):
        """Create the regressor.

        Parameters
        ----------
        min_variance : float, default=0.0
            Lower bound on each leaf variance used by ``predict(return_std=True)``.
        **kwargs
            Forwarded to scikit-learn's ``RandomForestRegressor``.
        """
        # Forward min_variance explicitly. Dropping it made TreeEnsembleBase
        # silently fall back to its 0.0 default.
        super().__init__(min_variance=min_variance, **kwargs)
55
56
57
class ExtraTreesRegressor(TreeEnsembleBase, _ExtraTreesRegressor_):
    """scikit-learn ``ExtraTreesRegressor`` whose ``predict`` can return a std."""

    def __init__(self, min_variance=0.0, **kwargs):
        """Create the regressor.

        Parameters
        ----------
        min_variance : float, default=0.0
            Lower bound on each leaf variance used by ``predict(return_std=True)``.
        **kwargs
            Forwarded to scikit-learn's ``ExtraTreesRegressor``.
        """
        # Forward min_variance explicitly. Dropping it made TreeEnsembleBase
        # silently fall back to its 0.0 default.
        super().__init__(min_variance=min_variance, **kwargs)
60
61
62
class GPR:
    """Thin wrapper around scikit-learn's ``GaussianProcessRegressor``.

    The process uses a Matern kernel with ``nu=2.5`` and ``normalize_y=True``.
    """

    def __init__(self):
        self.gpr = GaussianProcessRegressor(kernel=Matern(nu=2.5), normalize_y=True)

    def fit(self, X, y):
        """Fit the underlying Gaussian process on ``(X, y)``.

        Returns ``self``, following the scikit-learn convention, so that
        ``GPR().fit(X, y).predict(X)`` works.
        """
        self.gpr.fit(X, y)
        return self

    def predict(self, X, return_std=False):
        """Return the Gaussian process prediction for ``X``.

        The result of ``GaussianProcessRegressor.predict`` is passed through
        unchanged: the mean alone, or ``(mean, std)`` when ``return_std`` is
        true.
        """
        return self.gpr.predict(X, return_std=return_std)
71