Passed: Push to master ( bd7b54...01f089 ) by Simon
01:40 (queued 13s)

HyperactiveSearchCV.__init__() (grade A)

Complexity
    Conditions      1

Size
    Total Lines     24
    Code Lines      22

Duplication
    Lines           0
    Ratio           0 %

Importance
    Changes         0

Metric   Value
cc       1        (cyclomatic complexity; "Conditions" above)
eloc     22       (code lines)
nop      11       (number of parameters)
dl       0        (duplicated lines)
loc      24       (total lines)
rs       9.352
c        0
b        0
f        0

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you later need more, or different, data.

There are several approaches to avoiding long parameter lists; one common option, grouping related arguments into a parameter object, is sketched below.
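The following sketch is illustrative only and is not part of the Hyperactive codebase: SearchConfig and SearchCVSketch are hypothetical names, and the fields simply mirror the keyword arguments of HyperactiveSearchCV.__init__() shown in the listing further down.

from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass
class SearchConfig:
    # Hypothetical parameter object bundling the optional search settings
    # that would otherwise each be a separate __init__ argument.
    n_iter: int = 100
    scoring: Union[Callable, str, None] = None
    n_jobs: int = 1
    random_state: Optional[int] = None
    refit: bool = True
    cv: object = None


class SearchCVSketch:
    # Illustrative constructor: only the required arguments stay positional;
    # everything else travels inside the config object.
    def __init__(self, estimator, params_config, optimizer="default",
                 config: Optional[SearchConfig] = None):
        self.estimator = estimator
        self.params_config = params_config
        self.optimizer = optimizer
        self.config = config if config is not None else SearchConfig()

With such an object, a caller passes e.g. config=SearchConfig(n_iter=50, cv=5) instead of a long list of separate keyword arguments, and adding a new setting later only touches SearchConfig.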

# Author: Simon Blanke
# Email: [email protected]
# License: MIT License

from collections.abc import Iterable, Callable
from typing import Union, Dict, Type

from sklearn.base import BaseEstimator, clone
from sklearn.metrics import check_scoring
from sklearn.utils.validation import indexable, _check_method_params

from sklearn.base import BaseEstimator as SklearnBaseEstimator
from sklearn.model_selection import BaseCrossValidator

from hyperactive import Hyperactive

from .objective_function_adapter import ObjectiveFunctionAdapter
from .best_estimator import BestEstimator as _BestEstimator_
from .checks import Checks
from ...optimizers import RandomSearchOptimizer


class HyperactiveSearchCV(BaseEstimator, _BestEstimator_, Checks):
    """
    HyperactiveSearchCV class for hyperparameter tuning using cross-validation with sklearn estimators.

    Parameters:
    - estimator: SklearnBaseEstimator
        The estimator to be tuned.
    - params_config: Dict[str, list]
        Dictionary containing the hyperparameter search space.
    - optimizer: Union[str, Type[RandomSearchOptimizer]], optional
        The optimizer to be used for hyperparameter search, default is "default".
    - n_iter: int, optional
        Number of parameter settings that are sampled, default is 100.
    - scoring: Callable | str | None, optional
        Scoring method to evaluate the predictions on the test set.
    - n_jobs: int, optional
        Number of jobs to run in parallel, default is 1.
    - random_state: int | None, optional
        Random seed for reproducibility.
    - refit: bool, optional
        Refit the best estimator with the entire dataset, default is True.
    - cv: int | "BaseCrossValidator" | Iterable | None, optional
        Determines the cross-validation splitting strategy.

    Methods:
    - fit(X, y, **fit_params)
        Fit the estimator and tune hyperparameters.
    - score(X, y, **params)
        Return the score of the best estimator on the input data.
    """

    _required_parameters = ["estimator", "optimizer", "params_config"]

    def __init__(
        self,
        estimator: "SklearnBaseEstimator",
        params_config: Dict[str, list],
        optimizer: Union[str, Type[RandomSearchOptimizer]] = "default",
        n_iter: int = 100,
        *,
        scoring: Callable | str | None = None,
        n_jobs: int = 1,
        random_state: int | None = None,
        refit: bool = True,
        cv=None,
    ):
        super().__init__()

        self.estimator = estimator
        self.params_config = params_config
        self.optimizer = optimizer
        self.n_iter = n_iter
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.refit = refit
        self.cv = cv

    def _refit(self, X, y=None, **fit_params):
        # Re-fit a clone of the estimator with the best parameters found.
        self.best_estimator_ = clone(self.estimator).set_params(
            **clone(self.best_params_, safe=False)
        )

        self.best_estimator_.fit(X, y, **fit_params)
        return self

    @Checks.verify_fit
    def fit(self, X, y, **fit_params):
        X, y = indexable(X, y)
        X, y = self._validate_data(X, y)

        fit_params = _check_method_params(X, params=fit_params)
        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        # Wrap estimator, data, and validation strategy into a single
        # objective function for the optimizer.
        objective_function_adapter = ObjectiveFunctionAdapter(
            self.estimator,
        )
        objective_function_adapter.add_dataset(X, y)
        objective_function_adapter.add_validation(self.scorer_, self.cv)
        objective_function = objective_function_adapter.objective_function

        # Run the hyperparameter search.
        hyper = Hyperactive(verbosity=False)
        hyper.add_search(
            objective_function,
            search_space=self.params_config,
            optimizer=self.optimizer,
            n_iter=self.n_iter,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        hyper.run()

        # Collect the search results.
        self.best_params_ = hyper.best_para(objective_function)
        self.best_score_ = hyper.best_score(objective_function)
        self.search_data_ = hyper.search_data(objective_function)

        if self.refit:
            self._refit(X, y, **fit_params)

        return self

    def score(self, X, y=None, **params):
        return self.scorer_(self.best_estimator_, X, y, **params)

    @property
    def fit_successful(self):
        return self._fit_successful
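For context, a minimal usage sketch of the class above, based only on the signature and docstring in the listing; the dataset, estimator, and search-space values are arbitrary examples, and HyperactiveSearchCV is assumed to be importable in the current scope.

from sklearn.datasets import load_iris
from sklearn.svm import SVC

# Hypothetical example values; any sklearn estimator and search space
# matching the Dict[str, list] format of params_config would work.
X, y = load_iris(return_X_y=True)

search = HyperactiveSearchCV(
    estimator=SVC(),
    params_config={"C": [0.01, 0.1, 1, 10], "gamma": [0.001, 0.01, 0.1]},
    n_iter=25,
    cv=3,
)
search.fit(X, y)
print(search.best_params_, search.best_score_)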