Passed
Pull Request — master (#87)
by Simon
01:18
created

hyperactive.integrations.sklearn.hyperactive_search_cv   A

Complexity

Total Complexity 2

Size/Duplication

Total Lines 66
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 48
dl 0
loc 66
rs 10
c 0
b 0
f 0
wmc 2

2 Methods

Rating   Name   Duplication   Size   Complexity  
A HyperactiveSearchCV.fit() 0 25 1
A HyperactiveSearchCV.__init__() 0 22 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
from sklearn.base import BaseEstimator, clone
7
from sklearn.metrics import check_scoring
8
from sklearn.utils.validation import indexable, _check_method_params
9
10
from hyperactive import Hyperactive
11
12
from .objective_function_adapter import ObjectiveFunctionAdapter
13
14
15
class HyperactiveSearchCV(BaseEstimator):
    """Hyperparameter search for a scikit-learn estimator, driven by Hyperactive.

    The estimator, dataset and validation setup are wrapped in an
    ``ObjectiveFunctionAdapter``; its objective function is then optimized by
    a ``Hyperactive`` run over ``params_config``.

    Parameters
    ----------
    estimator : object
        The scikit-learn estimator whose hyperparameters are searched.
    optimizer : object
        Optimizer handed unchanged to ``Hyperactive.add_search(optimizer=...)``.
    params_config : object
        Search space handed to ``Hyperactive.add_search(search_space=...)``.
        The best parameters found in it are applied with
        ``estimator.set_params`` when refitting.
    n_iter : int, default=100
        Number of optimization iterations, forwarded to ``add_search``.
    scoring : str, callable or None, default=None
        Scoring strategy, resolved with ``sklearn.metrics.check_scoring``.
    n_jobs : int, default=1
        Forwarded to ``Hyperactive.add_search``.
    random_state : int or None, default=None
        Forwarded to ``Hyperactive.add_search``.
    refit : bool, default=True
        If truthy, a clone of ``estimator`` configured with ``best_params_`` is
        fitted on the full ``X, y`` and stored as ``best_estimator_``.
    cv : object, default=None
        Validation strategy handed to ``ObjectiveFunctionAdapter.add_validation``
        (presumably the usual scikit-learn ``cv`` values -- confirm against the
        adapter).

    Attributes
    ----------
    scorer_ : callable
        Scorer resolved from ``scoring``.
    best_params_ : dict
        Best parameters reported by Hyperactive.
    best_score_ : float
        Best objective value reported by Hyperactive.
    search_data_ : object
        Search history reported by ``Hyperactive.search_data``.
    best_estimator_ : object
        Clone of ``estimator`` refitted with ``best_params_``; only set when
        ``refit`` is truthy.
    """

    _required_parameters = ["estimator", "optimizer", "params_config"]

    def __init__(
        self,
        estimator,
        optimizer,
        params_config,
        n_iter=100,
        *,
        scoring=None,
        n_jobs=1,
        random_state=None,
        refit=True,
        cv=None,
    ):
        # scikit-learn convention: store constructor arguments unmodified so
        # that get_params/set_params/clone work; all validation happens in fit.
        self.estimator = estimator
        self.optimizer = optimizer
        self.params_config = params_config
        self.n_iter = n_iter
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.refit = refit
        self.cv = cv

    def fit(self, X, y, **params):
        """Run the hyperparameter search and optionally refit the best model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like
            Target values.
        **params : dict
            Fit parameters. They are validated with ``_check_method_params``
            and passed to ``fit`` of the refitted estimator.

        Returns
        -------
        self : HyperactiveSearchCV
            The fitted search object.
        """
        X, y = indexable(X, y)
        # NOTE(review): newer scikit-learn releases (1.6+) replace
        # ``BaseEstimator._validate_data`` with
        # ``sklearn.utils.validation.validate_data`` -- confirm the supported
        # scikit-learn version range.
        X, y = self._validate_data(X, y)

        params = _check_method_params(X, params=params)
        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        objective_function_adapter = ObjectiveFunctionAdapter(
            self.estimator,
        )
        objective_function_adapter.add_dataset(X, y)
        objective_function_adapter.add_validation(self.scorer_, self.cv)
        # Bind once so the same callable is used to register the search and
        # to look up its results afterwards.
        objective_function = objective_function_adapter.objective_function

        hyper = Hyperactive(verbosity=False)
        hyper.add_search(
            objective_function,
            search_space=self.params_config,
            optimizer=self.optimizer,
            n_iter=self.n_iter,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        hyper.run()

        # Hyperactive reports results per objective function.
        self.best_params_ = hyper.best_para(objective_function)
        self.best_score_ = hyper.best_score(objective_function)
        self.search_data_ = hyper.search_data(objective_function)

        if self.refit:
            # Refit on a clone so the user-supplied estimator is not mutated.
            self.best_estimator_ = clone(self.estimator).set_params(
                **self.best_params_
            )
            self.best_estimator_.fit(X, y, **params)

        return self
66