Passed
Pull Request — master (#101)
by Simon
01:48
created

BaseOptimizer.add_search()   B

Complexity

Conditions 3

Size

Total Lines 72
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 43
nop 15
dl 0
loc 72
rs 8.8478
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method (move a coherent part of the body into a new, well-named method) and Decompose Conditional (extract complex conditional logic into its own method).

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists: introduce a parameter object that groups related arguments, pass the whole object instead of several of its fields, or split the method into smaller methods that each need fewer parameters.

"""Base class for optimizer."""

import numpy as np
from typing import Union, List, Dict, Type
import copy
import multiprocessing as mp
import pandas as pd

from .backend_stuff.search_space import SearchSpace
from .backend_stuff.run_search import run_search
from .hyper_optimizer import HyperOptimizer
from .backend_stuff.results import Results
from .backend_stuff.print_results import PrintResults

from skbase.base import BaseObject

class BaseOptimizer(BaseObject):
    """Base class for optimizers.

    Wraps an optimizer class in a :class:`HyperOptimizer` and manages one or
    more search processes over a user-defined search space.  Searches are
    registered with :meth:`add_search` and executed with :meth:`run`; results
    are then available through :meth:`best_para`, :meth:`best_score` and
    :meth:`search_data`.
    """

    # NOTE(review): class-level mutable — all BaseOptimizer instances (and
    # subclasses) share this single registry of search processes, keyed by
    # process index.  If per-instance isolation is intended, move this
    # assignment into __init__.  Left as-is to preserve current behavior.
    opt_pros = {}

    def __init__(self, optimizer_class, opt_params):
        """Store the optimizer configuration and build the HyperOptimizer.

        Parameters
        ----------
        optimizer_class
            The concrete optimizer class to delegate the search to.
        opt_params : dict
            Keyword parameters forwarded to the optimizer class.
        """
        super().__init__()
        self.opt_params = opt_params
        self.hyper_optimizer = HyperOptimizer(optimizer_class, opt_params)

    @staticmethod
    def _default_search_id(search_id, objective_function):
        """Return ``search_id``, falling back to the objective function's name."""
        if not search_id:
            search_id = objective_function.__name__
        return search_id

    @staticmethod
    def check_list(search_space):
        """Warn about search-space dimensions whose values are not lists.

        Deliberately non-fatal: prints a warning instead of raising so that
        malformed search spaces keep working.  Promote to ``ValueError`` once
        callers are migrated.
        """
        for key, search_dim in search_space.items():
            if not isinstance(search_dim, list):
                error_msg = "Value in '{}' of search space dictionary must be of type list".format(
                    key
                )
                print("Warning", error_msg)

    def add_search(
        self,
        experiment: callable,
        search_space: Dict[str, list],
        n_iter: int,
        search_id=None,
        n_jobs: int = 1,
        verbosity: list = None,
        initialize: Dict[str, int] = None,
        constraints: List[callable] = None,
        pass_through: Dict = None,
        max_score: float = None,
        early_stopping: Dict = None,
        random_state: int = None,
        memory: Union[str, bool] = "share",
        memory_warm_start: pd.DataFrame = None,
    ):
        """
        Add a new optimization search process with specified parameters.

        Parameters
        ----------
        experiment
            Experiment object containing the objective function to optimize;
            its ``callbacks`` and ``catch`` attributes are forwarded to the
            search setup.
        search_space : dict of str to list
            Dictionary defining the search space for optimization.
        n_iter : int
            Number of iterations for the optimization process.
        search_id : optional
            Identifier for the search process; defaults to the objective
            function's name.
        n_jobs : int, default 1
            Number of parallel jobs to run; ``-1`` uses all available cores.
        verbosity : list, default ["progress_bar", "print_results", "print_times"]
            Output options enabled during the search.
        initialize : dict, default {"grid": 4, "random": 2, "vertices": 4}
            Initialization parameters for the optimizer.
        constraints : list of callable, optional
            Constraint functions applied to candidate parameters.
        pass_through : dict, optional
            Additional parameters passed through to the objective function.
        max_score : float, optional
            Stop the search once this score is reached.
        early_stopping : dict, optional
            Early-stopping criteria.
        random_state : int, optional
            Seed for random number generation.
        memory : str or bool, default "share"
            Whether/how to share evaluation memory between processes.
        memory_warm_start : pd.DataFrame, optional
            DataFrame of previous results used to warm-start the memory.
        """
        # Resolve the documented defaults here instead of using mutable
        # default arguments, which would be shared across calls.
        if verbosity is None:
            verbosity = ["progress_bar", "print_results", "print_times"]
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}
        constraints = constraints or []
        pass_through = pass_through or {}
        early_stopping = early_stopping or {}

        self.check_list(search_space)

        search_id = self._default_search_id(
            search_id, experiment.objective_function
        )
        s_space = SearchSpace(search_space)
        self.verbosity = verbosity

        self.hyper_optimizer.setup_search(
            experiment=experiment,
            s_space=s_space,
            n_iter=n_iter,
            initialize=initialize,
            constraints=constraints,
            pass_through=pass_through,
            callbacks=experiment.callbacks,
            catch=experiment.catch,
            max_score=max_score,
            early_stopping=early_stopping,
            random_state=random_state,
            memory=memory,
            memory_warm_start=memory_warm_start,
            verbosity=verbosity,
        )

        # -1 means "one worker per available CPU core".
        n_jobs = mp.cpu_count() if n_jobs == -1 else n_jobs

        # Register one entry per worker; they all share the same configured
        # HyperOptimizer instance.
        for _ in range(n_jobs):
            nth_process = len(self.opt_pros)
            self.opt_pros[nth_process] = self.hyper_optimizer

    def _print_info(self):
        """Print per-process results according to the verbosity settings."""
        print_res = PrintResults(self.opt_pros, self.verbosity)

        # One blank line per process to visually separate the result blocks.
        if self.verbosity:
            for _ in range(len(self.opt_pros)):
                print("")

        for results in self.results_list:
            nth_process = results["nth_process"]
            print_res.print_process(results, nth_process)

    def run(
        self,
        max_time=None,
        distribution: str = "multiprocessing",
        n_processes: Union[str, int] = "auto",
    ):
        """Execute all registered searches and collect their results.

        Parameters
        ----------
        max_time : optional
            Time budget applied to every registered optimizer.
        distribution : str, default "multiprocessing"
            Backend used to distribute the search processes.
        n_processes : str or int, default "auto"
            Number of processes to use, or "auto" to decide automatically.
        """
        for opt in self.opt_pros.values():
            opt.max_time = max_time

        self.results_list = run_search(self.opt_pros, distribution, n_processes)

        self.results_ = Results(self.results_list, self.opt_pros)

        self._print_info()

    def best_para(self, id_):
        """
        Retrieve the best parameters for a specific ID from the results.

        Parameters
        ----------
        id_
            The ID of the parameters to retrieve.

        Returns
        -------
        dict or None
            The best parameters for the specified ID if found, otherwise None.

        Raises
        ------
        ValueError
            If the objective function name is not recognized.
        """
        return self.results_.best_para(id_)

    def best_score(self, id_):
        """
        Return the best score for a specific ID from the results.

        Parameters
        ----------
        id_
            The ID for which the best score is requested.
        """
        return self.results_.best_score(id_)

    def search_data(self, id_, times=False):
        """
        Retrieve search data for a specific ID from the results.

        Parameters
        ----------
        id_
            Identifier of the search data to retrieve.
            NOTE(review): despite the name, this is accessed as an object
            with an ``objective_function`` attribute (e.g. an experiment),
            not an int — confirm against callers.
        times : bool, default False
            When False, evaluation and iteration time columns are dropped
            from the returned data.

        Returns
        -------
        pd.DataFrame
            The search data for the specified ID.
        """
        search_data_ = self.results_.search_data(id_.objective_function)

        if not times:
            # Drop timing columns in place; errors="ignore" tolerates their
            # absence.
            search_data_.drop(
                labels=["eval_times", "iter_times"],
                axis=1,
                inplace=True,
                errors="ignore",
            )
        return search_data_