Passed
Push — master ( 199200...2b2f7b )
by Simon
04:29
created

HyperOptimizer.convert_results2hyper()   A

Complexity

Conditions 2

Size

Total Lines 24
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 17
nop 1
dl 0
loc 24
rs 9.55
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import numpy as np
6
7
from .objective_function import ObjectiveFunction
8
from .hyper_gradient_conv import HyperGradientConv
9
from .optimizer_attributes import OptimizerAttributes
10
11
12
class HyperOptimizer(OptimizerAttributes):
    """Drive an underlying ("gfo") optimizer over a converted search space.

    The concrete optimizer is built from ``self.optimizer_class``. That
    attribute is not defined in this class; it is presumably supplied by a
    subclass (TODO confirm).

    Typical flow, as seen in this class:

    1. ``setup_search`` stores the search configuration.
    2. ``search`` builds the optimizer, runs the iteration loop and updates
       the progress bar.
    3. ``convert_results2hyper`` maps the optimizer's position-based results
       back to search-space values.
    """

    def __init__(self, **opt_params):
        """Store optimizer keyword arguments for later use.

        Args:
            **opt_params: Keyword arguments forwarded to
                ``self.optimizer_class`` when a search process is set up.
        """
        super().__init__()
        self.opt_params = opt_params

    def setup_search(
        self,
        objective_function,
        s_space,
        n_iter,
        initialize,
        pass_through,
        callbacks,
        catch,
        max_score,
        early_stopping,
        random_state,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Store the configuration for a later call to ``search``.

        Every argument is saved as an attribute of the same name. The one
        exception is ``verbosity``, which is reduced to either
        ``["progress_bar"]`` or ``[]``.

        Note: ``max_time`` is not part of this configuration. ``search``
        uses ``self.max_time`` if a caller has set it.
        """
        self.objective_function = objective_function
        self.s_space = s_space
        self.n_iter = n_iter

        self.initialize = initialize
        self.pass_through = pass_through
        self.callbacks = callbacks
        self.catch = catch
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.random_state = random_state
        self.memory = memory
        self.memory_warm_start = memory_warm_start

        # Only the progress-bar flag is relevant at this level. A falsy
        # verbosity (e.g. None or False) is treated as "no output" instead
        # of raising TypeError on the membership test.
        if verbosity and "progress_bar" in verbosity:
            self.verbosity = ["progress_bar"]
        else:
            self.verbosity = []

    def convert_results2hyper(self):
        """Translate the finished optimizer's results into search-space terms.

        Sets ``eval_times`` and ``iter_times`` (summed), ``best_para``,
        ``best_score``, ``positions``, ``search_data`` and
        ``memory_values_df``. The last one keeps the first row per unique
        position, restricted to the dimension columns plus ``"score"``.
        """
        self.eval_times = np.array(self.gfo_optimizer.eval_times).sum()
        self.iter_times = np.array(self.gfo_optimizer.iter_times).sum()

        gfo_best_para = self.gfo_optimizer.best_para
        if gfo_best_para is None:
            self.best_para = None
        else:
            # The optimizer was given ``s_space.positions`` as its search
            # space, so its best parameters are positions. They are mapped
            # back to the user's values: para dict -> positions -> values
            # -> para dict.
            best_positions = self.hg_conv.para2value(gfo_best_para)
            best_values = self.hg_conv.position2value(best_positions)
            self.best_para = self.hg_conv.value2para(best_values)

        self.best_score = self.gfo_optimizer.best_score
        self.positions = self.gfo_optimizer.search_data

        self.search_data = self.hg_conv.positions2results(self.positions)

        # Keep the first evaluation of each unique position only.
        dim_keys = self.s_space.dim_keys
        unique_results = self.gfo_optimizer.search_data.drop_duplicates(
            subset=dim_keys, keep="first"
        )
        self.memory_values_df = unique_results[dim_keys + ["score"]].reset_index(
            drop=True
        )

    def _setup_process(self, nth_process):
        """Create the converter and the concrete optimizer for one process.

        Args:
            nth_process: Index of this search process. It is forwarded to the
                optimizer and stored on ``self``.
        """
        self.nth_process = nth_process

        self.hg_conv = HyperGradientConv(self.s_space)

        initialize = self.hg_conv.conv_initialize(self.initialize)
        search_space_positions = self.s_space.positions

        # Convert the SMBO warm start from values into positions. This is
        # done on a copy so that ``self.opt_params`` keeps the user's
        # original values; mutating it in place would re-convert
        # already-converted data if this method ran again.
        opt_params = dict(self.opt_params)
        if "warm_start_smbo" in opt_params:
            opt_params["warm_start_smbo"] = self.hg_conv.conv_memory_warm_start(
                opt_params["warm_start_smbo"]
            )

        self.gfo_optimizer = self.optimizer_class(
            search_space=search_space_positions,
            initialize=initialize,
            random_state=self.random_state,
            nth_process=nth_process,
            **opt_params
        )

        self.conv = self.gfo_optimizer.conv

    def search(self, nth_process, p_bar):
        """Run up to ``n_iter`` optimization steps and record the results.

        Args:
            nth_process: Index of this search process.
            p_bar: Progress-bar object. ``set_description``,
                ``set_postfix``, ``update`` and ``refresh`` are called on it
                (presumably a tqdm-like bar; TODO confirm).
        """
        self._setup_process(nth_process)

        gfo_wrapper_model = ObjectiveFunction(
            objective_function=self.objective_function,
            optimizer=self.gfo_optimizer,
            callbacks=self.callbacks,
            catch=self.catch,
            nth_process=self.nth_process,
        )
        gfo_wrapper_model.pass_through = self.pass_through

        memory_warm_start = self.hg_conv.conv_memory_warm_start(self.memory_warm_start)

        gfo_objective_function = gfo_wrapper_model(self.s_space())

        # ``max_time`` is not set by ``setup_search``; callers are expected
        # to assign it. Fall back to None when it is absent.
        # NOTE(review): assumes None means "no time limit" for the optimizer.
        max_time = getattr(self, "max_time", None)

        self.gfo_optimizer.init_search(
            gfo_objective_function,
            self.n_iter,
            max_time,
            self.max_score,
            self.early_stopping,
            self.memory,
            memory_warm_start,
            False,
        )

        # The description does not change between iterations, so build it
        # once.
        description = "[{}] {} ({})".format(
            nth_process,
            self.objective_function.__name__,
            self.optimizer_class.name,
        )

        for nth_iter in range(self.n_iter):
            p_bar.set_description(description)

            self.gfo_optimizer.search_step(nth_iter)
            if self.gfo_optimizer.stop.check():
                break

            optimizer = gfo_wrapper_model.optimizer
            p_bar.set_postfix(
                best_score=str(optimizer.score_best),
                best_pos=str(optimizer.pos_best),
                best_iter=str(optimizer.p_bar._best_since_iter),
            )

            p_bar.update(1)
            p_bar.refresh()

        self.gfo_optimizer.finish_search()

        self.convert_results2hyper()

        # ``_add_result_attributes`` is not defined in this class; presumably
        # it is provided by ``OptimizerAttributes`` (TODO confirm).
        self._add_result_attributes(
            self.best_para,
            self.best_score,
            self.gfo_optimizer.p_bar._best_since_iter,
            self.eval_times,
            self.iter_times,
            self.search_data,
            self.gfo_optimizer.random_seed,
        )