Passed
Pull Request — master (#101)
by Simon
01:48
created

hyperactive.optimizers._optimizer_api   A

Complexity

Total Complexity 19

Size/Duplication

Total Lines 193
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 105
dl 0
loc 193
rs 10
c 0
b 0
f 0
wmc 19

9 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseOptimizer.check_list() 0 10 3
A BaseOptimizer._default_search_id() 0 5 2
A BaseOptimizer.__init__() 0 4 1
A BaseOptimizer.best_score() 0 9 1
A BaseOptimizer.run() 0 14 2
A BaseOptimizer.best_para() 0 15 1
A BaseOptimizer.search_data() 0 22 2
B BaseOptimizer.add_search() 0 72 3
A BaseOptimizer._print_info() 0 10 4
1
"""Base class for optimizer."""
2
3
import numpy as np
4
from typing import Union, List, Dict, Type
5
import copy
6
import multiprocessing as mp
7
import pandas as pd
8
9
from .backend_stuff.search_space import SearchSpace
10
from .backend_stuff.run_search import run_search
11
from .hyper_optimizer import HyperOptimizer
12
from .backend_stuff.results import Results
13
from .backend_stuff.print_results import PrintResults
14
15
from skbase.base import BaseObject
16
17
18
class BaseOptimizer(BaseObject):
    """Base class for optimizer.

    Searches are registered with ``add_search`` and executed together with
    ``run``; afterwards ``best_para``, ``best_score`` and ``search_data``
    read from the collected ``Results``.
    """

    def __init__(self, optimizer_class, opt_params):
        super().__init__()
        self.opt_params = opt_params
        # Unconfigured template. ``add_search`` configures a deep copy of it
        # per call, so separate searches never overwrite each other's setup.
        self.hyper_optimizer = HyperOptimizer(optimizer_class, opt_params)
        # Registered search processes, keyed by process index. This must be
        # per instance: a class-level dict would be shared by every optimizer
        # object (and subclass) in the interpreter.
        self.opt_pros = {}

    @staticmethod
    def _default_search_id(search_id, objective_function):
        """Return ``search_id``, or the objective function's ``__name__`` if it is falsy."""
        if not search_id:
            search_id = objective_function.__name__
        return search_id

    @staticmethod
    def check_list(search_space):
        """Warn about search-space dimensions whose values are not lists.

        Parameters:
        - search_space: Dictionary mapping parameter names to their values.

        Non-list values only trigger a ``UserWarning``; they are not rejected.
        """
        import warnings

        for key, search_dim in search_space.items():
            if not isinstance(search_dim, list):
                warnings.warn(
                    "Value in '{}' of search space dictionary must be of type list".format(
                        key
                    ),
                    # point at the code that called add_search()
                    stacklevel=3,
                )

    def add_search(
        self,
        experiment: callable,
        search_space: Dict[str, list],
        n_iter: int,
        search_id=None,
        n_jobs: int = 1,
        verbosity: list = ["progress_bar", "print_results", "print_times"],
        initialize: Dict[str, int] = {"grid": 4, "random": 2, "vertices": 4},
        constraints: List[callable] = None,
        pass_through: Dict = None,
        max_score: float = None,
        early_stopping: Dict = None,
        random_state: int = None,
        memory: Union[str, bool] = "share",
        memory_warm_start: pd.DataFrame = None,
    ):
        """
        Add a new optimization search process with specified parameters.

        Parameters:
        - experiment: Experiment object; its ``objective_function``,
          ``callbacks`` and ``catch`` attributes are used by this method.
        - search_space: Dictionary defining the search space for optimization
          (values are expected to be lists; other types trigger a warning).
        - n_iter: Number of iterations for the optimization process.
        - search_id: Identifier for the search process (default: None, which
          resolves to ``experiment.objective_function.__name__``).
        - n_jobs: Number of processes to register for this search
          (default: 1; -1 means ``multiprocessing.cpu_count()``).
        - verbosity: Output options; also stored on ``self.verbosity`` for the
          result printing done by ``run`` (default:
          ["progress_bar", "print_results", "print_times"]).
        - initialize: Dictionary specifying initialization parameters
          (default: {"grid": 4, "random": 2, "vertices": 4}).
        - constraints: List of constraint functions (default: None -> []).
        - pass_through: Dictionary of additional parameters to pass through
          (default: None -> {}).
        - max_score: Maximum score to achieve (default: None).
        - early_stopping: Dictionary specifying early stopping criteria
          (default: None -> {}).
        - random_state: Seed for random number generation (default: None).
        - memory: Memory option forwarded to the optimizer (default: "share").
        - memory_warm_start: DataFrame containing warm start memory
          (default: None).

        Callbacks and exceptions to catch are not parameters of this method;
        they are taken from ``experiment.callbacks`` and ``experiment.catch``.
        """
        self.check_list(search_space)

        # The list/dict defaults in the signature are created once and shared
        # by every call; hand out copies so downstream code cannot mutate them.
        verbosity = copy.copy(verbosity)
        initialize = copy.copy(initialize)
        constraints = constraints or []
        pass_through = pass_through or {}
        early_stopping = early_stopping or {}

        # NOTE(review): search_id is resolved but not forwarded to
        # setup_search; confirm whether HyperOptimizer should receive it.
        search_id = self._default_search_id(
            search_id, experiment.objective_function
        )
        s_space = SearchSpace(search_space)
        self.verbosity = verbosity

        # Configure a private copy of the template so that a later add_search
        # call cannot reconfigure an optimizer that is already registered.
        hyper_optimizer = copy.deepcopy(self.hyper_optimizer)
        hyper_optimizer.setup_search(
            experiment=experiment,
            s_space=s_space,
            n_iter=n_iter,
            initialize=initialize,
            constraints=constraints,
            pass_through=pass_through,
            callbacks=experiment.callbacks,
            catch=experiment.catch,
            max_score=max_score,
            early_stopping=early_stopping,
            random_state=random_state,
            memory=memory,
            memory_warm_start=memory_warm_start,
            verbosity=verbosity,
        )

        n_jobs = mp.cpu_count() if n_jobs == -1 else n_jobs

        # Each job gets the next free process index; the jobs of this call
        # all reference the same configured optimizer object.
        for _ in range(n_jobs):
            self.opt_pros[len(self.opt_pros)] = hyper_optimizer

    def _print_info(self):
        """Print the per-process results collected by the last ``run``."""
        print_res = PrintResults(self.opt_pros, self.verbosity)

        if self.verbosity:
            # one blank line per registered process
            for _ in self.opt_pros:
                print("")

        for results in self.results_list:
            print_res.print_process(results, results["nth_process"])

    def run(
        self,
        max_time=None,
        distribution: str = "multiprocessing",
        n_processes: Union[str, int] = "auto",
    ):
        """
        Execute every search registered via ``add_search``.

        Parameters:
        - max_time: Value assigned to ``max_time`` on every registered
          optimizer before the run (default: None).
        - distribution: Distribution backend passed to ``run_search``
          (default: "multiprocessing").
        - n_processes: Process count passed to ``run_search``
          (default: "auto").

        Stores the raw per-process results in ``self.results_list`` and the
        aggregated ``Results`` in ``self.results_``, then prints them.
        """
        for opt in self.opt_pros.values():
            opt.max_time = max_time

        self.results_list = run_search(self.opt_pros, distribution, n_processes)

        self.results_ = Results(self.results_list, self.opt_pros)

        self._print_info()

    def best_para(self, id_):
        """
        Retrieve the best parameters for a specific search from the results.

        Parameters:
        - id_: Search identifier, forwarded unchanged to
          ``Results.best_para``.

        Returns:
        - The best parameters found for that search, as returned by
          ``Results.best_para``.

        NOTE(review): ``Results.best_para`` presumably raises ``ValueError``
        for an unknown identifier; confirm against ``Results``.
        """

        return self.results_.best_para(id_)

    def best_score(self, id_):
        """
        Return the best score for a specific search from the results.

        Parameters:
        - id_: Search identifier, forwarded unchanged to
          ``Results.best_score``.
        """

        return self.results_.best_score(id_)

    def search_data(self, id_, times=False):
        """
        Retrieve the search data for a specific search from the results.

        Parameters:
        - id_: Object whose ``objective_function`` attribute identifies the
          search, e.g. the experiment passed to ``add_search``.
          NOTE(review): ``best_para``/``best_score`` forward ``id_`` as-is,
          while this method forwards ``id_.objective_function``; confirm
          which form ``Results`` expects.
        - times (bool, optional): Whether to include the "eval_times" and
          "iter_times" columns. Defaults to False, which removes them.

        Returns:
        - pd.DataFrame: The search data for the specified search.
        """

        search_data_ = self.results_.search_data(id_.objective_function)

        if not times:
            # Non-inplace drop returns a new frame, so the frame held by
            # ``self.results_`` keeps its timing columns for later calls.
            search_data_ = search_data_.drop(
                columns=["eval_times", "iter_times"],
                errors="ignore",
            )
        return search_data_
193