Passed
Push — master ( bd7b54...01f089 )
by Simon
01:40 queued 13s
created

HyperactiveSearchCV._refit()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 4
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
from collections.abc import Iterable, Callable
6
from typing import Union, Dict, Type
7
8
from sklearn.base import BaseEstimator, clone
9
from sklearn.metrics import check_scoring
10
from sklearn.utils.validation import indexable, _check_method_params
11
12
from sklearn.base import BaseEstimator as SklearnBaseEstimator
13
from sklearn.model_selection import BaseCrossValidator
14
15
from hyperactive import Hyperactive
16
17
from .objective_function_adapter import ObjectiveFunctionAdapter
18
from .best_estimator import BestEstimator as _BestEstimator_
19
from .checks import Checks
20
from ...optimizers import RandomSearchOptimizer
21
22
23
class HyperactiveSearchCV(BaseEstimator, _BestEstimator_, Checks):
    """
    HyperactiveSearchCV class for hyperparameter tuning using cross-validation with sklearn estimators.

    Parameters:
    - estimator: SklearnBaseEstimator
        The estimator to be tuned.
    - params_config: Dict[str, list]
        Dictionary containing the hyperparameter search space.
    - optimizer: Union[str, Type[RandomSearchOptimizer]], optional
        The optimizer to be used for hyperparameter search, default is "default".
    - n_iter: int, optional
        Number of parameter settings that are sampled, default is 100.
    - scoring: Callable | str | None, optional
        Scoring method to evaluate the predictions on the test set.
        Resolved with ``sklearn.metrics.check_scoring`` when ``fit`` is called.
    - n_jobs: int, optional
        Number of jobs to run in parallel, default is 1.
    - random_state: int | None, optional
        Random seed for reproducibility.
    - refit: bool, optional
        Refit the best estimator with the entire dataset, default is True.
        When False, ``best_estimator_`` is not created and ``score`` is unavailable.
    - cv: int | "BaseCrossValidator" | Iterable | None, optional
        Determines the cross-validation splitting strategy.

    Attributes (set by fit):
    - scorer_
        Scorer resolved from ``scoring`` via ``check_scoring``.
    - best_params_
        Best parameter combination found by the search.
    - best_score_
        Score associated with ``best_params_``.
    - search_data_
        Search history as returned by ``Hyperactive.search_data``.
    - best_estimator_
        Clone of ``estimator`` with ``best_params_`` fitted on the full data
        (only when ``refit`` is True).

    Methods:
    - fit(X, y, **fit_params)
        Fit the estimator and tune hyperparameters.
    - score(X, y, **params)
        Return the score of the best estimator on the input data.
    """

    _required_parameters = ["estimator", "optimizer", "params_config"]

    def __init__(
        self,
        estimator: "SklearnBaseEstimator",
        params_config: Dict[str, list],
        optimizer: Union[str, Type[RandomSearchOptimizer]] = "default",
        n_iter: int = 100,
        *,
        scoring: Callable | str | None = None,
        n_jobs: int = 1,
        random_state: int | None = None,
        refit: bool = True,
        cv=None,
    ):
        super().__init__()

        # Constructor arguments are stored unmodified; all validation and
        # derived state is deferred to ``fit``.
        self.estimator = estimator
        self.params_config = params_config
        self.optimizer = optimizer
        self.n_iter = n_iter
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.refit = refit
        self.cv = cv

    def _refit(self, X, y=None, **fit_params):
        """
        Fit a fresh clone of ``estimator`` configured with ``best_params_``.

        The fitted clone is stored in ``best_estimator_``. Returns ``self``.
        """
        # clone(..., safe=False) also copies non-estimator parameter values, so
        # the refitted estimator does not share mutable state with best_params_.
        self.best_estimator_ = clone(self.estimator).set_params(
            **clone(self.best_params_, safe=False)
        )

        self.best_estimator_.fit(X, y, **fit_params)
        return self

    @Checks.verify_fit
    def fit(self, X, y, **fit_params):
        """
        Run the hyperparameter search on (X, y) and optionally refit.

        Parameters:
        - X: training data.
        - y: target values.
        - **fit_params: validated with ``_check_method_params`` and forwarded
          only to the final refit of the best estimator; they are not passed
          to the evaluations performed during the search.

        Returns:
        - self
        """
        X, y = indexable(X, y)
        X, y = self._validate_data(X, y)

        fit_params = _check_method_params(X, params=fit_params)
        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        # Bundle estimator, data, scorer and CV strategy into the objective
        # function that Hyperactive optimizes.
        objective_function_adapter = ObjectiveFunctionAdapter(self.estimator)
        objective_function_adapter.add_dataset(X, y)
        objective_function_adapter.add_validation(self.scorer_, self.cv)
        objective_function = objective_function_adapter.objective_function

        hyper = Hyperactive(verbosity=False)
        hyper.add_search(
            objective_function,
            search_space=self.params_config,
            optimizer=self.optimizer,
            n_iter=self.n_iter,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        hyper.run()

        # Results are looked up by the objective function they were registered
        # with; each is queried exactly once.
        self.best_params_ = hyper.best_para(objective_function)
        self.best_score_ = hyper.best_score(objective_function)
        self.search_data_ = hyper.search_data(objective_function)

        if self.refit:
            self._refit(X, y, **fit_params)

        return self

    def score(self, X, y=None, **params):
        """
        Score ``best_estimator_`` on (X, y) using the scorer resolved in ``fit``.

        Raises:
        - AttributeError: if no refitted best estimator exists (``fit`` was not
          called, or it ran with ``refit=False``).
        """
        if not hasattr(self, "best_estimator_"):
            raise AttributeError(
                f"{type(self).__name__} has no best_estimator_; score is only "
                "available after fit has been called with refit=True."
            )
        return self.scorer_(self.best_estimator_, X, y, **params)

    @property
    def fit_successful(self):
        """
        Return the ``_fit_successful`` flag.

        NOTE(review): ``_fit_successful`` is not assigned in this class -
        presumably it is set by ``Checks.verify_fit``; confirm.
        """
        return self._fit_successful
132