Passed
Pull Request — master (#87)
by Simon
01:22
created

HyperactiveSearchCV.score()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 4
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
from sklearn.base import BaseEstimator, clone
7
from sklearn.metrics import check_scoring
8
from sklearn.utils.validation import indexable, _check_method_params
9
10
11
from hyperactive import Hyperactive
12
13
from .objective_function_adapter import ObjectiveFunctionAdapter
14
from .best_estimator import BestEstimator
15
16
17
class HyperactiveSearchCV(BaseEstimator, BestEstimator):
    """Scikit-learn compatible hyperparameter search driven by Hyperactive.

    ``fit`` runs a Hyperactive search over ``params_config`` with the given
    ``optimizer``. Each candidate is evaluated through an
    ``ObjectiveFunctionAdapter`` that wraps ``estimator`` together with the
    resolved scorer and ``cv``. The best parameters found are stored in
    ``best_params_``. When ``refit`` is true, a clone of ``estimator``
    configured with those parameters is fitted on the full data and stored
    as ``best_estimator_``.

    Parameters
    ----------
    estimator : estimator object
        The scikit-learn estimator to tune. ``best_estimator_`` is a clone of
        it; the object passed here is not refitted by ``_refit``.
    optimizer : object
        Optimizer passed unchanged to ``Hyperactive.add_search``.
    params_config : dict
        Search space, passed as ``search_space`` to ``Hyperactive.add_search``.
        Its keys are expected to be parameter names of ``estimator`` because
        the best values are applied with ``set_params``.
    n_iter : int, default=100
        Number of search iterations, passed to ``Hyperactive.add_search``.
    scoring : str, callable or None, default=None
        Scoring specification resolved with
        ``sklearn.metrics.check_scoring``.
    n_jobs : int, default=1
        Passed to ``Hyperactive.add_search``.
    random_state : int or None, default=None
        Passed to ``Hyperactive.add_search``.
    refit : bool, default=True
        Whether to fit ``best_estimator_`` on the whole dataset after the
        search.
    cv : object, default=None
        Cross-validation specification passed to
        ``ObjectiveFunctionAdapter.add_validation`` (presumably following the
        scikit-learn ``cv`` convention -- confirm in the adapter).

    Attributes
    ----------
    scorer_ : callable
        Scorer returned by ``check_scoring``.
    best_params_ : dict
        Best parameters, as returned by ``Hyperactive.best_para``.
    best_score_ : float
        Best objective value, as returned by ``Hyperactive.best_score``.
    search_data_ : object
        Search history, as returned by ``Hyperactive.search_data``.
    best_estimator_ : estimator object
        Clone of ``estimator`` set to ``best_params_`` and fitted on the full
        data. Only present when ``refit`` is true.
    """

    # Constructor arguments without defaults. Presumably consumed by
    # scikit-learn's estimator-check utilities -- confirm.
    _required_parameters = ["estimator", "optimizer", "params_config"]

    def __init__(
        self,
        estimator,
        optimizer,
        params_config,
        n_iter=100,
        *,
        scoring=None,
        n_jobs=1,
        random_state=None,
        refit=True,
        cv=None,
    ):
        # Store constructor arguments verbatim, as scikit-learn's
        # get_params/set_params/clone machinery requires.
        self.estimator = estimator
        self.optimizer = optimizer
        self.params_config = params_config
        self.n_iter = n_iter
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.refit = refit
        self.cv = cv

    def _refit(
        self,
        X,
        y=None,
        **fit_params,
    ):
        """Fit a clone of ``estimator`` configured with ``best_params_``.

        Parameters
        ----------
        X, y
            Training data forwarded to the clone's ``fit``.
        **fit_params
            Extra keyword arguments forwarded to the clone's ``fit``.

        Returns
        -------
        self
        """
        # Copy the parameter values (clone with safe=False deep-copies
        # non-estimator objects) so that estimator-valued parameters are not
        # shared between best_params_ and best_estimator_. This mirrors
        # scikit-learn's BaseSearchCV refit.
        best_params = clone(self.best_params_, safe=False)
        self.best_estimator_ = clone(self.estimator).set_params(**best_params)
        self.best_estimator_.fit(X, y, **fit_params)
        return self

    def fit(self, X, y, **params):
        """Run the hyperparameter search and optionally refit the best model.

        Parameters
        ----------
        X, y
            Training data. They are made indexable and validated before the
            search.
        **params
            Fit parameters. They are checked with ``_check_method_params`` and
            forwarded to the final refit.

        Returns
        -------
        self
        """
        X, y = indexable(X, y)
        X, y = self._validate_data(X, y)

        params = _check_method_params(X, params=params)
        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        objective_function_adapter = ObjectiveFunctionAdapter(self.estimator)
        objective_function_adapter.add_dataset(X, y)
        objective_function_adapter.add_validation(self.scorer_, self.cv)
        # Keep one reference: Hyperactive's result accessors below take the
        # same objective function that was registered with add_search.
        objective_function = objective_function_adapter.objective_function

        hyper = Hyperactive(verbosity=False)
        hyper.add_search(
            objective_function,
            search_space=self.params_config,
            optimizer=self.optimizer,
            n_iter=self.n_iter,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        hyper.run()

        # Collect the search results so that _refit can apply best_params_.
        self.best_params_ = hyper.best_para(objective_function)
        self.best_score_ = hyper.best_score(objective_function)
        self.search_data_ = hyper.search_data(objective_function)

        if self.refit:
            self._refit(X, y, **params)

        return self

    def score(self, X, y=None, **params):
        """Score ``best_estimator_`` on ``X``/``y`` using ``scorer_``.

        Extra keyword arguments are forwarded to the scorer.

        Raises
        ------
        AttributeError
            If there is no ``best_estimator_``, e.g. ``fit`` was not called
            or was called with ``refit=False``.
        """
        if not hasattr(self, "best_estimator_"):
            raise AttributeError(
                f"{type(self).__name__} has no 'best_estimator_'; call fit "
                "with refit=True before calling score."
            )
        return self.scorer_(self.best_estimator_, X, y, **params)
84