Total Complexity | 4 |
Total Lines | 57 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | """Base class for optimizer.""" |
||
2 | # copyright: hyperactive developers, MIT License (see LICENSE file) |
||
3 | |||
4 | from skbase.base import BaseObject |
||
5 | |||
6 | |||
class BaseOptimizer(BaseObject):
    """Base class for optimizer.

    Concrete optimizers are expected to take an ``experiment`` constructor
    parameter and set ``self.experiment`` before this class's ``__init__``
    runs, and to implement ``_run``. All other constructor parameters form
    the search configuration, which ``run`` forwards to ``_run`` as keyword
    arguments.
    """

    _tags = {
        "object_type": "optimizer",
        "python_dependencies": None,
    }

    def __init__(self):
        super().__init__()
        # Explicit check instead of ``assert`` so the validation is not
        # stripped when Python runs with ``-O``. AssertionError is kept so
        # callers see the same exception type as before.
        if not hasattr(self, "experiment"):
            raise AssertionError("Optimizer must have an experiment.")
        # Only the top-level ``experiment`` value is needed; deep=False avoids
        # recursing into the experiment's own parameters (and matches
        # ``get_search_config``).
        self._experiment = self.get_params(deep=False).get("experiment", None)

    def get_search_config(self):
        """Get the search configuration.

        The search configuration is the optimizer's own (non-nested)
        parameters with the ``experiment`` entry removed.

        Returns
        -------
        dict with str keys
            The search configuration dictionary.
        """
        search_config = self.get_params(deep=False)
        search_config.pop("experiment", None)
        return search_config

    def get_experiment(self):
        """Get the experiment.

        Returns
        -------
        BaseExperiment
            The experiment to optimize parameters for, as captured in
            ``__init__``.
        """
        return self._experiment

    def run(self):
        """Run the optimization search process.

        Delegates to ``_run`` with the experiment and the search
        configuration, and stores the result in ``self.best_params_``.

        Returns
        -------
        best_params : dict
            The best parameters found during the optimization process.
        """
        experiment = self.get_experiment()
        search_config = self.get_search_config()

        best_params = self._run(experiment, **search_config)
        self.best_params_ = best_params
        return best_params

    def _run(self, experiment, **search_config):
        """Run the optimization; must be implemented by concrete optimizers.

        Parameters
        ----------
        experiment : BaseExperiment
            The experiment to optimize parameters for.
        **search_config
            The search configuration, as returned by ``get_search_config``.

        Returns
        -------
        best_params : dict
            The best parameters found during the optimization process.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _run(experiment, **search_config)."
        )
||
57 |