HyperOptimizer._setup_process()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 31
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 20
nop 2
dl 0
loc 31
rs 9.4
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import numpy as np
6
7
from .objective_function import ObjectiveFunction
8
from .hyper_gradient_conv import HyperGradientConv
9
from .optimizer_attributes import OptimizerAttributes
10
from .constraint import Constraint
11
12
13
class HyperOptimizer(OptimizerAttributes):
    """Drive one search process with a backend optimizer class.

    Subclasses are expected to provide ``optimizer_class``: it is read in
    ``_setup_process`` and ``search`` but never assigned in this class.
    ``max_time`` (read in ``search``) is likewise not assigned here —
    NOTE(review): presumably set by the caller before ``search`` runs; confirm.
    The ``gfo_`` prefix on attributes suggests the backend is
    gradient-free-optimizers; that is an assumption not shown in this file.
    """

    def __init__(self, **opt_params):
        """Store optimizer-specific keyword arguments.

        ``opt_params`` is later forwarded to ``optimizer_class`` in
        ``_setup_process`` (with ``warm_start_smbo``, if present, converted
        from values into positions first).
        """
        super().__init__()
        self.opt_params = opt_params

    def setup_search(
        self,
        objective_function,
        s_space,
        n_iter,
        initialize,
        constraints,
        pass_through,
        callbacks,
        catch,
        max_score,
        early_stopping,
        random_state,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Record the search configuration on the instance.

        Every argument is stored unchanged as an attribute of the same name,
        except ``verbosity``: it becomes ``["progress_bar"]`` when it contains
        ``"progress_bar"`` and ``[]`` otherwise. A falsy ``verbosity``
        (``None``, ``False``, empty) is treated as "no output" rather than
        raising ``TypeError`` from the membership test.
        """
        self.objective_function = objective_function
        self.s_space = s_space
        self.n_iter = n_iter

        self.initialize = initialize
        self.constraints = constraints
        self.pass_through = pass_through
        self.callbacks = callbacks
        self.catch = catch
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.random_state = random_state
        self.memory = memory
        self.memory_warm_start = memory_warm_start

        # Only the progress-bar flag is relevant to this class; everything
        # else in ``verbosity`` is dropped.
        if verbosity and "progress_bar" in verbosity:
            self.verbosity = ["progress_bar"]
        else:
            self.verbosity = []

    def convert_results2hyper(self):
        """Copy results from the backend optimizer into instance attributes.

        Sets:
        - ``eval_times`` / ``iter_times``: sums of the backend's per-step lists.
        - ``best_para``: the backend's ``best_para`` round-tripped through
          ``hg_conv``, or ``None`` if the backend has none.
        - ``best_score``: the backend's ``best_score``.
        - ``positions``: the backend's raw ``search_data``.
        - ``search_data``: ``positions`` converted by ``hg_conv.positions2results``.
        - ``memory_values_df``: backend ``search_data`` de-duplicated on
          ``s_space.dim_keys`` (first occurrence kept), restricted to
          ``dim_keys + ["score"]``, with a fresh 0..n-1 index.
        """
        self.eval_times = sum(self.gfo_optimizer.eval_times)
        self.iter_times = sum(self.gfo_optimizer.iter_times)

        if self.gfo_optimizer.best_para is not None:
            value = self.hg_conv.para2value(self.gfo_optimizer.best_para)
            position = self.hg_conv.position2value(value)
            self.best_para = self.hg_conv.value2para(position)
        else:
            self.best_para = None

        self.best_score = self.gfo_optimizer.best_score
        self.positions = self.gfo_optimizer.search_data
        self.search_data = self.hg_conv.positions2results(self.positions)

        results_dd = self.gfo_optimizer.search_data.drop_duplicates(
            subset=self.s_space.dim_keys, keep="first"
        )
        self.memory_values_df = results_dd[
            self.s_space.dim_keys + ["score"]
        ].reset_index(drop=True)

    def _setup_process(self, nth_process):
        """Build the converter and backend optimizer for process ``nth_process``.

        Sets ``nth_process``, ``hg_conv``, ``gfo_optimizer`` and ``conv``.
        ``self.opt_params`` is not mutated: the ``warm_start_smbo`` conversion
        is applied to a shallow copy. Repeated calls on the same instance
        therefore never re-convert data that was already converted into
        positions.
        """
        self.nth_process = nth_process
        self.hg_conv = HyperGradientConv(self.s_space)

        initialize = self.hg_conv.conv_initialize(self.initialize)

        # Shallow copy: keep the user-supplied opt_params in value space so
        # this method is safe to call more than once.
        opt_params = dict(self.opt_params)
        # conv warm start for smbo from values into positions
        if "warm_start_smbo" in opt_params:
            opt_params["warm_start_smbo"] = self.hg_conv.conv_memory_warm_start(
                opt_params["warm_start_smbo"]
            )

        gfo_constraints = [
            Constraint(constraint, self.s_space) for constraint in self.constraints
        ]

        self.gfo_optimizer = self.optimizer_class(
            search_space=self.s_space.positions,
            initialize=initialize,
            constraints=gfo_constraints,
            random_state=self.random_state,
            nth_process=nth_process,
            **opt_params,
        )
        self.conv = self.gfo_optimizer.conv

    def search(self, nth_process, p_bar):
        """Run a complete search for one process and store the results.

        Parameters
        ----------
        nth_process : int
            Index of this process; passed to setup and shown in the
            progress-bar label.
        p_bar : object or None
            Progress bar whose ``set_description``, ``set_postfix``,
            ``update`` and ``refresh`` methods are called; a falsy value
            disables progress reporting.

        Results are stored via ``convert_results2hyper`` and
        ``_add_result_attributes`` (the latter is not defined in this class;
        presumably inherited from ``OptimizerAttributes``).
        """
        self._setup_process(nth_process)

        gfo_wrapper_model = ObjectiveFunction(
            objective_function=self.objective_function,
            optimizer=self.gfo_optimizer,
            callbacks=self.callbacks,
            catch=self.catch,
            nth_process=self.nth_process,
        )
        gfo_wrapper_model.pass_through = self.pass_through

        memory_warm_start = self.hg_conv.conv_memory_warm_start(
            self.memory_warm_start
        )

        gfo_objective_function = gfo_wrapper_model(self.s_space())

        self.gfo_optimizer.init_search(
            gfo_objective_function,
            self.n_iter,
            self.max_time,
            self.max_score,
            self.early_stopping,
            self.memory,
            memory_warm_start,
            False,
        )

        # The label is loop-invariant, so build it once instead of per step.
        description = None
        if p_bar:
            description = (
                f"[{nth_process}] "
                f"{self.objective_function.__name__} "
                f"({self.optimizer_class.name})"
            )

        for nth_iter in range(self.n_iter):
            if p_bar:
                p_bar.set_description(description)

            self.gfo_optimizer.search_step(nth_iter)
            # The stop check runs before the progress bar is advanced, so the
            # step that triggers stopping is not reflected in the bar.
            if self.gfo_optimizer.stop.check():
                break

            if p_bar:
                optimizer = gfo_wrapper_model.optimizer
                p_bar.set_postfix(
                    best_score=str(optimizer.score_best),
                    best_pos=str(optimizer.pos_best),
                    best_iter=str(optimizer.p_bar._best_since_iter),
                )
                p_bar.update(1)
                p_bar.refresh()

        self.gfo_optimizer.finish_search()

        self.convert_results2hyper()

        self._add_result_attributes(
            self.best_para,
            self.best_score,
            self.gfo_optimizer.p_bar._best_since_iter,
            self.eval_times,
            self.iter_times,
            self.search_data,
            self.gfo_optimizer.random_seed,
        )
180