Passed
Pull Request — master (#101)
by Simon
02:38 queued 01:05
created

BaseOptimizer.add_search()   B

Complexity

Conditions 3

Size

Total Lines 76
Code Lines 47

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 47
nop 17
dl 0
loc 76
rs 8.7345
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method (moving a cohesive chunk of the body into a well-named helper) and Decompose Conditional.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Introduce Parameter Object (grouping related parameters into a single object) or Preserve Whole Object.

1
"""Base class for optimizer."""
2
3
import numpy as np
4
from typing import Union, List, Dict, Type
5
import copy
6
import multiprocessing as mp
7
import pandas as pd
8
9
from .backend_stuff.search_space import SearchSpace
10
from .backend_stuff.run_search import run_search
11
from .hyper_optimizer import HyperOptimizer
12
from .backend_stuff.results import Results
13
from .backend_stuff.print_results import PrintResults
14
15
from skbase.base import BaseObject
16
17
18
class BaseOptimizer(BaseObject):
    """Base class for optimizer.

    Wraps an optimizer class and its parameters in a ``HyperOptimizer`` and
    manages the search processes that ``run`` later executes.
    """

    # Registry of search processes, keyed by process index.
    # NOTE(review): this is a *class-level* mutable dict, so every
    # BaseOptimizer instance shares the same registry. ``run`` consumes it
    # via ``self.opt_pros`` — confirm the cross-instance sharing is intended.
    opt_pros = {}

    def __init__(self, optimizer_class, opt_params):
        super().__init__()
        self.opt_params = opt_params
        self.hyper_optimizer = HyperOptimizer(optimizer_class, opt_params)

    @staticmethod
    def _default_search_id(search_id, objective_function):
        """Return ``search_id``, falling back to the objective function's name when falsy."""
        if not search_id:
            search_id = objective_function.__name__
        return search_id

    @staticmethod
    def check_list(search_space):
        """Warn about search-space dimensions whose value is not a list.

        Deliberately warns instead of raising so that a malformed search
        space does not abort setup.
        """
        for key, search_dim in search_space.items():
            error_msg = "Value in '{}' of search space dictionary must be of type list".format(
                key
            )
            if not isinstance(search_dim, list):
                print("Warning", error_msg)

    def add_search(
        self,
        experiment: callable,
        search_space: Dict[str, list],
        n_iter: int,
        search_id=None,
        n_jobs: int = 1,
        verbosity: list = None,
        initialize: Dict[str, int] = None,
        constraints: List[callable] = None,
        pass_through: Dict = None,
        callbacks: Dict[str, callable] = None,
        catch: Dict = None,
        max_score: float = None,
        early_stopping: Dict = None,
        random_state: int = None,
        memory: Union[str, bool] = "share",
        memory_warm_start: pd.DataFrame = None,
    ):
        """
        Add a new optimization search process with specified parameters.

        Parameters:
        - experiment: Experiment whose ``_score`` method is used as the objective function to optimize.
        - search_space: Dictionary defining the search space for optimization.
        - n_iter: Number of iterations for the optimization process.
        - search_id: Identifier for the search process (default: None, which uses the objective function's name).
        - n_jobs: Number of parallel jobs to run; -1 uses all available CPUs (default: 1).
        - verbosity: List of enabled output kinds (default: ["progress_bar", "print_results", "print_times"]).
        - initialize: Dictionary specifying initialization parameters (default: {"grid": 4, "random": 2, "vertices": 4}).
        - constraints: List of constraint functions (default: None).
        - pass_through: Dictionary of additional parameters to pass through (default: None).
        - callbacks: Dictionary of callback functions (default: None).
        - catch: Dictionary of exceptions to catch during optimization (default: None).
        - max_score: Maximum score to achieve (default: None).
        - early_stopping: Dictionary specifying early stopping criteria (default: None).
        - random_state: Seed for random number generation (default: None).
        - memory: Option to share memory between processes (default: "share").
        - memory_warm_start: DataFrame containing warm start memory (default: None).
        """
        # Resolve shared defaults here instead of using mutable default
        # arguments, which would be shared (and mutable) across all calls.
        if verbosity is None:
            verbosity = ["progress_bar", "print_results", "print_times"]
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}

        objective_function = experiment._score

        self.check_list(search_space)

        constraints = constraints or []
        pass_through = pass_through or {}
        callbacks = callbacks or {}
        catch = catch or {}
        early_stopping = early_stopping or {}

        search_id = self._default_search_id(search_id, objective_function)
        s_space = SearchSpace(search_space)
        self.verbosity = verbosity

        self.hyper_optimizer.setup_search(
            objective_function=objective_function,
            s_space=s_space,
            n_iter=n_iter,
            initialize=initialize,
            constraints=constraints,
            pass_through=pass_through,
            callbacks=callbacks,
            catch=catch,
            max_score=max_score,
            early_stopping=early_stopping,
            random_state=random_state,
            memory=memory,
            memory_warm_start=memory_warm_start,
            verbosity=verbosity,
        )

        # -1 is the conventional "use every available core" sentinel.
        n_jobs = mp.cpu_count() if n_jobs == -1 else n_jobs

        # Register the same configured optimizer once per parallel job.
        for _ in range(n_jobs):
            nth_process = len(self.opt_pros)
            self.opt_pros[nth_process] = self.hyper_optimizer

    def _print_info(self):
        """Print per-process results according to the verbosity settings."""
        print_res = PrintResults(self.opt_pros, self.verbosity)

        # Visually separate the result output from the progress bars above.
        if self.verbosity:
            for _ in range(len(self.opt_pros)):
                print("")

        for results in self.results_list:
            nth_process = results["nth_process"]
            print_res.print_process(results, nth_process)

    def run(
        self,
        max_time=None,
        distribution: str = "multiprocessing",
        n_processes: Union[str, int] = "auto",
    ):
        """Run all registered search processes and collect their results.

        Parameters:
        - max_time: Time budget applied to every registered optimizer (default: None).
        - distribution: Backend used to distribute the searches (default: "multiprocessing").
        - n_processes: Number of worker processes, or "auto" (default: "auto").
        """
        for opt in self.opt_pros.values():
            opt.max_time = max_time

        self.results_list = run_search(self.opt_pros, distribution, n_processes)

        self.results_ = Results(self.results_list, self.opt_pros)

        self._print_info()