1
|
|
|
# Author: Simon Blanke
2
|
|
|
# Email: [email protected]
3
|
|
|
# License: MIT License
4
|
|
|
|
5
|
|
|
from collections.abc import Iterable, Callable |
6
|
|
|
from typing import Union, Dict, Type |
7
|
|
|
|
8
|
|
|
from sklearn.base import BaseEstimator, clone |
9
|
|
|
from sklearn.metrics import check_scoring |
10
|
|
|
from sklearn.utils.validation import indexable, _check_method_params |
11
|
|
|
|
12
|
|
|
from sklearn.base import BaseEstimator as SklearnBaseEstimator |
13
|
|
|
from sklearn.model_selection import BaseCrossValidator |
14
|
|
|
|
15
|
|
|
from hyperactive import Hyperactive |
16
|
|
|
|
17
|
|
|
from .objective_function_adapter import ObjectiveFunctionAdapter |
18
|
|
|
from .best_estimator import BestEstimator as _BestEstimator_ |
19
|
|
|
from .checks import Checks |
20
|
|
|
from ...optimizers import RandomSearchOptimizer |
21
|
|
|
|
22
|
|
|
|
23
|
|
|
class HyperactiveSearchCV(BaseEstimator, _BestEstimator_, Checks):
    """
    HyperactiveSearchCV class for hyperparameter tuning using cross-validation with sklearn estimators.

    Parameters:
    - estimator: SklearnBaseEstimator
        The estimator to be tuned.
    - params_config: Dict[str, list]
        Dictionary containing the hyperparameter search space.
    - optimizer: Union[str, Type[RandomSearchOptimizer]], optional
        The optimizer to be used for hyperparameter search, default is "default".
    - n_iter: int, optional
        Number of parameter settings that are sampled, default is 100.
    - scoring: Callable | str | None, optional
        Scoring method to evaluate the predictions on the test set.
    - n_jobs: int, optional
        Number of jobs to run in parallel, default is 1.
    - random_state: int | None, optional
        Random seed for reproducibility.
    - refit: bool, optional
        Refit the best estimator with the entire dataset, default is True.
    - cv: int | "BaseCrossValidator" | Iterable | None, optional
        Determines the cross-validation splitting strategy.

    Attributes (set by ``fit``):
    - scorer_
        Scorer returned by ``check_scoring`` for ``estimator`` and ``scoring``.
    - best_params_
        Best parameters reported by the Hyperactive run.
    - best_score_
        Best score reported by the Hyperactive run.
    - search_data_
        Search data reported by the Hyperactive run.
    - best_estimator_
        Clone of ``estimator`` configured with ``best_params_`` and fitted on
        the data passed to ``fit``; only set by this class when ``refit`` is
        truthy.

    Methods:
    - fit(X, y, **fit_params)
        Fit the estimator and tune hyperparameters.
    - score(X, y, **params)
        Return the score of the best estimator on the input data.
    """

    _required_parameters = ["estimator", "optimizer", "params_config"]

    def __init__(
        self,
        estimator: "SklearnBaseEstimator",
        params_config: Dict[str, list],
        optimizer: Union[str, Type[RandomSearchOptimizer]] = "default",
        n_iter: int = 100,
        *,
        scoring: Callable | str | None = None,
        n_jobs: int = 1,
        random_state: int | None = None,
        refit: bool = True,
        cv=None,
    ):
        """Store the search configuration unchanged on the instance."""
        super().__init__()

        self.estimator = estimator
        self.params_config = params_config
        self.optimizer = optimizer
        self.n_iter = n_iter
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.refit = refit
        self.cv = cv

    def _refit(self, X, y=None, **fit_params):
        """
        Fit a fresh clone of the estimator using the best parameters found.

        Parameters:
        - X, y
            Data forwarded to the cloned estimator's ``fit``.
        - **fit_params
            Extra keyword arguments forwarded to the cloned estimator's ``fit``.

        Returns:
        - self
        """
        # Clone both the estimator and the parameter values so that neither
        # the user-supplied estimator nor ``best_params_`` is shared with the
        # fitted ``best_estimator_``.
        self.best_estimator_ = clone(self.estimator).set_params(
            **clone(self.best_params_, safe=False)
        )

        self.best_estimator_.fit(X, y, **fit_params)
        return self

    @Checks.verify_fit
    def fit(self, X, y, **fit_params):
        """
        Run the hyperparameter search and optionally refit the best estimator.

        Parameters:
        - X, y
            Training data; made indexable and validated before the search.
        - **fit_params
            Extra keyword arguments, checked with ``_check_method_params`` and
            forwarded to the final refit.

        Returns:
        - self
        """
        # NOTE(review): ``verify_fit`` is defined in ``Checks``; presumably it
        # records whether this call succeeded in ``_fit_successful`` - confirm.
        X, y = indexable(X, y)
        # NOTE(review): ``_validate_data`` is a private sklearn method; newer
        # sklearn releases may only offer ``validate_data`` - confirm against
        # the supported sklearn versions.
        X, y = self._validate_data(X, y)

        fit_params = _check_method_params(X, params=fit_params)
        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        objective_function_adapter = ObjectiveFunctionAdapter(
            self.estimator,
        )
        objective_function_adapter.add_dataset(X, y)
        objective_function_adapter.add_validation(self.scorer_, self.cv)
        objective_function = objective_function_adapter.objective_function

        hyper = Hyperactive(verbosity=False)
        hyper.add_search(
            objective_function,
            search_space=self.params_config,
            optimizer=self.optimizer,
            n_iter=self.n_iter,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        hyper.run()

        # Results are looked up by the objective function that was registered
        # with ``add_search`` above; each result is fetched exactly once.
        self.best_params_ = hyper.best_para(objective_function)
        self.best_score_ = hyper.best_score(objective_function)
        self.search_data_ = hyper.search_data(objective_function)

        if self.refit:
            self._refit(X, y, **fit_params)

        return self

    def score(self, X, y=None, **params):
        """
        Return the score of the best estimator on the input data.

        Calls the scorer stored by ``fit`` with ``best_estimator_``, ``X``,
        ``y`` and any extra keyword arguments.
        """
        return self.scorer_(self.best_estimator_, X, y, **params)

    @property
    def fit_successful(self):
        """Return the ``_fit_successful`` flag stored on the instance."""
        return self._fit_successful
132
|
|
|
|