Passed
Pull Request — master (#101)
by Simon
02:38 queued 01:05
created

hyperactive.optimizers._optimizer_api   A

Complexity

Total Complexity 15

Size/Duplication

Total Lines 148
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 96
dl 0
loc 148
rs 10
c 0
b 0
f 0
wmc 15

6 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseOptimizer.run() 0 14 2
A BaseOptimizer.check_list() 0 10 3
A BaseOptimizer._default_search_id() 0 5 2
A BaseOptimizer.__init__() 0 4 1
B BaseOptimizer.add_search() 0 76 3
A BaseOptimizer._print_info() 0 10 4
1
"""Base class for optimizer."""
2
3
import numpy as np
4
from typing import Union, List, Dict, Type
5
import copy
6
import multiprocessing as mp
7
import pandas as pd
8
9
from .backend_stuff.search_space import SearchSpace
10
from .backend_stuff.run_search import run_search
11
from .hyper_optimizer import HyperOptimizer
12
from .backend_stuff.results import Results
13
from .backend_stuff.print_results import PrintResults
14
15
from skbase.base import BaseObject
16
17
18
class BaseOptimizer(BaseObject):
    """Base class for optimizer.

    Holds a ``HyperOptimizer`` template built from ``optimizer_class`` and
    ``opt_params``. Each :meth:`add_search` call configures a private copy
    of that template and registers it in ``opt_pros`` once per job;
    :meth:`run` executes every registered process via ``run_search``.
    """

    def __init__(self, optimizer_class, opt_params):
        super().__init__()
        self.opt_params = opt_params
        # Template only: add_search deep-copies it before configuring, so
        # this object itself is never set up.
        self.hyper_optimizer = HyperOptimizer(optimizer_class, opt_params)
        # Maps process index (0, 1, 2, ...) -> HyperOptimizer to run.
        # Created per instance on purpose: a class-level dict would be shared
        # by every optimizer instance (and subclass), so one optimizer's
        # run() would also execute searches added to any other optimizer,
        # and previously added searches would never be cleared.
        self.opt_pros = {}

    @staticmethod
    def _default_search_id(search_id, objective_function):
        """Return ``search_id``, or ``objective_function.__name__`` if it is falsy."""
        if not search_id:
            search_id = objective_function.__name__
        return search_id

    @staticmethod
    def check_list(search_space):
        """Print a warning for every search-space value that is not a list.

        Parameters:
        - search_space: Mapping of dimension name to candidate values.

        Non-list values are only reported (printed), not rejected.
        """
        for key, search_dim in search_space.items():
            if not isinstance(search_dim, list):
                error_msg = "Value in '{}' of search space dictionary must be of type list".format(
                    key
                )
                print("Warning", error_msg)
                # raise ValueError(error_msg)

    def add_search(
        self,
        experiment: callable,
        search_space: Dict[str, list],
        n_iter: int,
        search_id=None,
        n_jobs: int = 1,
        verbosity: list = ["progress_bar", "print_results", "print_times"],
        initialize: Dict[str, int] = {"grid": 4, "random": 2, "vertices": 4},
        constraints: List[callable] = None,
        pass_through: Dict = None,
        callbacks: Dict[str, callable] = None,
        catch: Dict = None,
        max_score: float = None,
        early_stopping: Dict = None,
        random_state: int = None,
        memory: Union[str, bool] = "share",
        memory_warm_start: pd.DataFrame = None,
    ):
        """
        Add a new optimization search process with specified parameters.

        Each call configures its own deep copy of ``self.hyper_optimizer``,
        so several searches can be added to the same optimizer without
        overwriting one another. The configured copy is registered
        ``n_jobs`` times in ``self.opt_pros``.

        Parameters:
        - experiment: Object whose ``_score`` attribute is used as the
          objective function to optimize.
        - search_space: Dictionary defining the search space for optimization.
        - n_iter: Number of iterations for the optimization process.
        - search_id: Identifier for the search process (default: None, which
          falls back to the objective function's ``__name__``).
        - n_jobs: Number of parallel jobs to run (default: 1); -1 means
          ``multiprocessing.cpu_count()``.
        - verbosity: Verbosity options forwarded to the optimizer and used by
          ``run`` when printing results
          (default: ["progress_bar", "print_results", "print_times"]).
        - initialize: Dictionary specifying initialization parameters (default: {"grid": 4, "random": 2, "vertices": 4}).
        - constraints: List of constraint functions (default: None).
        - pass_through: Dictionary of additional parameters to pass through (default: None).
        - callbacks: Dictionary of callback functions (default: None).
        - catch: Dictionary of exceptions to catch during optimization (default: None).
        - max_score: Maximum score to achieve (default: None).
        - early_stopping: Dictionary specifying early stopping criteria (default: None).
        - random_state: Seed for random number generation (default: None).
        - memory: Option to share memory between processes (default: "share").
        - memory_warm_start: DataFrame containing warm start memory (default: None).
        """
        # NOTE(review): `verbosity` and `initialize` have mutable default
        # values shared across calls; this is only safe as long as nothing
        # downstream mutates them in place.
        objective_function = experiment._score

        self.check_list(search_space)

        # Normalise "not given" to empty containers for the backend.
        constraints = constraints or []
        pass_through = pass_through or {}
        callbacks = callbacks or {}
        catch = catch or {}
        early_stopping = early_stopping or {}

        # NOTE(review): search_id is resolved here but not forwarded to
        # setup_search; kept so the original lookup behaviour is unchanged.
        search_id = self._default_search_id(search_id, objective_function)
        s_space = SearchSpace(search_space)
        self.verbosity = verbosity

        # Configure a private copy so repeated add_search calls do not
        # overwrite each other's setup on the shared template object.
        hyper_optimizer = copy.deepcopy(self.hyper_optimizer)
        hyper_optimizer.setup_search(
            objective_function=objective_function,
            s_space=s_space,
            n_iter=n_iter,
            initialize=initialize,
            constraints=constraints,
            pass_through=pass_through,
            callbacks=callbacks,
            catch=catch,
            max_score=max_score,
            early_stopping=early_stopping,
            random_state=random_state,
            memory=memory,
            memory_warm_start=memory_warm_start,
            verbosity=verbosity,
        )

        n_jobs = mp.cpu_count() if n_jobs == -1 else n_jobs

        # One registry entry per job, numbered after any existing entries.
        for _ in range(n_jobs):
            nth_process = len(self.opt_pros)
            self.opt_pros[nth_process] = hyper_optimizer

    def _print_info(self):
        """Print the results of each finished process.

        When ``self.verbosity`` is truthy, first prints one blank line per
        registered process. Then passes each entry of ``self.results_list``
        to ``PrintResults.print_process`` with its ``"nth_process"`` index.
        """
        print_res = PrintResults(self.opt_pros, self.verbosity)

        if self.verbosity:
            for _ in range(len(self.opt_pros)):
                print("")

        for results in self.results_list:
            print_res.print_process(results, results["nth_process"])

    def run(
        self,
        max_time=None,
        distribution: str = "multiprocessing",
        n_processes: Union[str, int] = "auto",
    ):
        """Run every registered search process.

        Parameters:
        - max_time: Assigned to ``max_time`` of each registered optimizer
          before running (default: None).
        - distribution: Forwarded to ``run_search`` (default: "multiprocessing").
        - n_processes: Forwarded to ``run_search`` (default: "auto").

        Sets ``self.results_list`` (raw ``run_search`` output) and
        ``self.results_`` (a ``Results`` wrapper), then prints the results.
        """
        for opt in self.opt_pros.values():
            opt.max_time = max_time

        self.results_list = run_search(self.opt_pros, distribution, n_processes)

        self.results_ = Results(self.results_list, self.opt_pros)

        self._print_info()
148