HyperactiveSearchCV._check_data()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Hyperactive cross-validation search for scikit-learn integration.
2
3
Author: Simon Blanke
4
Email: [email protected]
5
License: MIT License
6
"""
7
8
from collections.abc import Callable
9
from typing import Union
10
11
from sklearn.base import BaseEstimator, clone
12
from sklearn.base import BaseEstimator as SklearnBaseEstimator
13
from sklearn.metrics import check_scoring
14
15
from hyperactive import Hyperactive
16
from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment
17
18
from ...optimizers import RandomSearchOptimizer
19
from ._compat import _check_method_params, _safe_refit, _safe_validate_X_y
20
from .best_estimator import BestEstimator as _BestEstimator_
21
from .checks import Checks
22
23
24
class HyperactiveSearchCV(BaseEstimator, _BestEstimator_, Checks):
    """Hyperparameter search for scikit-learn estimators driven by Hyperactive.

    The class exposes a scikit-learn style ``fit`` / ``score`` interface.
    ``fit`` cross-validates candidate parameter sets through a
    ``SklearnCvExperiment`` and lets a ``Hyperactive`` run pick the best one.

    Parameters
    ----------
    estimator : SklearnBaseEstimator
        The estimator to be tuned.
    params_config : dict[str, list]
        Dictionary containing the hyperparameter search space.
    optimizer : str or type[RandomSearchOptimizer], default="default"
        The optimizer to be used for hyperparameter search.
    n_iter : int, default=100
        Number of parameter settings that are sampled.
    scoring : Callable, str or None, default=None
        Scoring method to evaluate the predictions on the test set.
    n_jobs : int, default=1
        Number of jobs to run in parallel.
    random_state : int or None, default=None
        Random seed for reproducibility.
    refit : bool, default=True
        Refit the best estimator with the entire dataset.
    cv : int, "BaseCrossValidator", Iterable or None, default=None
        Determines the cross-validation splitting strategy.

    Attributes
    ----------
    scorer_ : callable
        Scorer resolved by ``check_scoring`` during ``fit``.
    best_params_ : dict
        Best parameter set reported by the Hyperactive run.
    best_score_ : float
        Score reported by the Hyperactive run for ``best_params_``.
    search_data_ : object
        Search history returned by ``Hyperactive.search_data``.
    best_estimator_ : estimator
        Clone of ``estimator`` configured with ``best_params_`` and fitted
        by ``_refit``. Only available once ``_refit`` has run.

    Methods
    -------
    fit(X, y, **fit_params)
        Fit the estimator and tune hyperparameters.
    score(X, y, **params)
        Return the score of the best estimator on the input data.
    """

    _required_parameters = ["estimator", "optimizer", "params_config"]

    def __init__(
        self,
        estimator: "SklearnBaseEstimator",
        params_config: dict[str, list],
        optimizer: Union[str, type[RandomSearchOptimizer]] = "default",
        n_iter: int = 100,
        *,
        scoring: Union[Callable, str, None] = None,
        n_jobs: int = 1,
        random_state: Union[int, None] = None,
        refit: bool = True,
        cv=None,
    ):
        super().__init__()

        # Constructor arguments are stored unmodified under their own names;
        # all processing of them happens in ``fit``.
        self.estimator = estimator
        self.params_config = params_config
        self.optimizer = optimizer
        self.n_iter = n_iter
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.refit = refit
        self.cv = cv

    def _refit(self, X, y=None, **fit_params):
        """Fit a fresh clone of the estimator configured with the best parameters.

        Parameters
        ----------
        X : array-like
            Training input samples passed on to the estimator's ``fit``.
        y : array-like, default=None
            Target values passed on to the estimator's ``fit``.
        **fit_params : dict
            Additional keyword arguments forwarded to the estimator's ``fit``.

        Returns
        -------
        self : object
            The instance itself, with ``best_estimator_`` set.
        """
        # ``safe=False`` lets ``clone`` accept the plain parameter dict, so the
        # fitted estimator receives copies instead of the objects stored in
        # ``best_params_``.
        self.best_estimator_ = clone(self.estimator).set_params(
            **clone(self.best_params_, safe=False)
        )

        self.best_estimator_.fit(X, y, **fit_params)
        return self

    def _check_data(self, X, y):
        """Validate ``X`` and ``y`` by delegating to ``_safe_validate_X_y``.

        Returns whatever the helper returns; ``fit`` unpacks it as ``(X, y)``.
        """
        return _safe_validate_X_y(self, X, y)

    @Checks.verify_fit
    def fit(self, X, y, **fit_params):
        """Fit the estimator using the provided training data.

        Parameters
        ----------
        X : array-like or sparse matrix, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_outputs)
            The target values.
        **fit_params : dict of string -> object
            Additional fit parameters.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        X, y = self._check_data(X, y)

        fit_params = _check_method_params(X, params=fit_params)
        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        # The experiment's ``score`` method is the objective Hyperactive
        # optimizes over the search space.
        experiment = SklearnCvExperiment(
            estimator=self.estimator,
            scoring=self.scorer_,
            cv=self.cv,
            X=X,
            y=y,
        )
        objective_function = experiment.score

        hyper = Hyperactive(verbosity=False)
        hyper.add_search(
            objective_function,
            search_space=self.params_config,
            optimizer=self.optimizer,
            n_iter=self.n_iter,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        hyper.run()

        # Results are looked up by the objective that was registered above.
        self.best_params_ = hyper.best_para(objective_function)
        self.best_score_ = hyper.best_score(objective_function)
        self.search_data_ = hyper.search_data(objective_function)

        # NOTE(review): presumably calls ``_refit`` when ``self.refit`` is
        # truthy -- confirm against ``_compat._safe_refit``.
        _safe_refit(self, X, y, fit_params)

        return self

    def score(self, X, y=None, **params):
        """Calculate the score of the best estimator on the input data.

        Parameters
        ----------
        X : array-like or sparse matrix of shape (n_samples, n_features)
            The input samples.
        y : array-like of shape (n_samples,), default=None
            The target values.
        **params : dict
            Additional parameters to be passed to the scoring function.

        Returns
        -------
        float
            The score of the best estimator on the input data.
        """
        return self.scorer_(self.best_estimator_, X, y, **params)

    @property
    def fit_successful(self):
        """Return the value stored in ``_fit_successful``.

        NOTE(review): ``_fit_successful`` is not assigned in this class; it is
        presumably set by ``Checks.verify_fit`` -- confirm.
        """
        # The bare expression previously here discarded the value, so the
        # property always evaluated to ``None``.
        return self._fit_successful
172