Passed
Push — master ( 0c1576...8c2877 )
by Simon
04:40 queued 12s
created

HyperOptimizer.__init__()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import numpy as np
6
import pandas as pd
7
8
9
from .objective_function import ObjectiveFunction
10
from .hyper_gradient_conv import HyperGradientConv
11
from .base_optimizer import BaseOptimizer
12
13
14
class HyperOptimizer(BaseOptimizer):
    """Bridge between a Hyperactive search space and an optimizer backend.

    The backend algorithm class is read from ``self._OptimizerClass``, which
    is not defined in this class.  NOTE(review): presumably supplied by a
    per-algorithm subclass -- confirm.

    Call order visible from this class: ``setup_search(...)`` stores the
    search configuration; ``search(nth_process)`` then builds the backend via
    ``_setup_process``, runs it, and converts its results with
    ``convert_results2hyper``.
    """

    def __init__(self, **opt_params):
        """Store algorithm-specific keyword arguments.

        ``opt_params`` are forwarded to the backend constructor in
        ``_setup_process``.  The only exception is ``warm_start_smbo``, whose
        converted copy is passed instead.
        """
        super().__init__()
        self.opt_params = opt_params

    def setup_search(
        self,
        objective_function,
        s_space,
        n_iter,
        initialize,
        pass_through,
        callbacks,
        catch,
        max_score,
        early_stopping,
        random_state,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Store the configuration consumed later by ``search``.

        Every argument is saved as an attribute of the same name, except
        ``verbosity``.  It becomes ``["progress_bar"]`` when it contains
        ``"progress_bar"`` and ``[]`` otherwise.  A falsy ``verbosity``
        (``None``, ``False``, empty) yields ``[]`` instead of raising
        ``TypeError`` on the membership test.
        """
        self.objective_function = objective_function
        self.s_space = s_space
        self.n_iter = n_iter

        self.initialize = initialize
        self.pass_through = pass_through
        self.callbacks = callbacks
        self.catch = catch
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.random_state = random_state
        self.memory = memory
        self.memory_warm_start = memory_warm_start

        # Only the progress bar is forwarded to the backend; any other
        # verbosity flags are dropped at this point.
        if verbosity and "progress_bar" in verbosity:
            self.verbosity = ["progress_bar"]
        else:
            self.verbosity = []

    def convert_results2hyper(self):
        """Translate the backend's results into Hyperactive's representation.

        Sets the following attributes:

        - ``eval_times`` / ``iter_times``: sums of the backend's timing lists.
        - ``best_para``: the best parameters, or ``None`` if the backend
          reported none.
        - ``best_score``: the backend's best score.
        - ``positions``: the backend's raw ``search_data``.
        - ``search_data``: ``positions`` converted via
          ``hg_conv.positions2results``.
        - ``memory_values_df``: the first occurrence of each distinct point
          over ``s_space.dim_keys``, plus its ``score`` column, with a fresh
          index.
        """
        self.eval_times = np.array(self.opt_algo.eval_times).sum()
        self.iter_times = np.array(self.opt_algo.iter_times).sum()

        if self.opt_algo.best_para is None:
            self.best_para = None
        else:
            # The backend searches over ``s_space.positions`` (see
            # _setup_process), so its best parameters are mapped back through
            # position2value before being rebuilt into a parameter dict.
            best_positions = self.hg_conv.para2value(self.opt_algo.best_para)
            best_values = self.hg_conv.position2value(best_positions)
            self.best_para = self.hg_conv.value2para(best_values)

        self.best_score = self.opt_algo.best_score
        self.positions = self.opt_algo.search_data

        self.search_data = self.hg_conv.positions2results(self.positions)

        # Keep only the first evaluation of each distinct point.
        unique_rows = self.positions.drop_duplicates(
            subset=self.s_space.dim_keys, keep="first"
        )
        self.memory_values_df = unique_rows[
            self.s_space.dim_keys + ["score"]
        ].reset_index(drop=True)

    def _setup_process(self, nth_process):
        """Build the converter and the backend optimizer for one process.

        The user's ``opt_params`` are left untouched.  The converted
        ``warm_start_smbo`` goes into a per-call copy, so calling this method
        more than once never converts already-converted data a second time.
        """
        self.nth_process = nth_process

        self.hg_conv = HyperGradientConv(self.s_space)

        initialize = self.hg_conv.conv_initialize(self.initialize)
        search_space_positions = self.s_space.positions

        # Convert the SMBO warm start from values into positions, on a copy
        # of the params rather than in place.
        opt_params = dict(self.opt_params)
        if "warm_start_smbo" in opt_params:
            opt_params["warm_start_smbo"] = self.hg_conv.conv_memory_warm_start(
                opt_params["warm_start_smbo"]
            )

        self.opt_algo = self._OptimizerClass(
            search_space=search_space_positions,
            initialize=initialize,
            random_state=self.random_state,
            nth_process=nth_process,
            **opt_params
        )

        self.conv = self.opt_algo.conv

    def search(self, nth_process):
        """Run the backend optimizer as process ``nth_process``.

        Requires ``setup_search`` to have been called first.  After the run,
        the results are converted and handed to ``_add_result_attributes``.
        That method is not defined in this class; NOTE(review): presumably
        inherited from BaseOptimizer.
        """
        self._setup_process(nth_process)

        gfo_wrapper_model = ObjectiveFunction(
            objective_function=self.objective_function,
            optimizer=self.opt_algo,
            callbacks=self.callbacks,
            catch=self.catch,
            nth_process=self.nth_process,
        )
        gfo_wrapper_model.pass_through = self.pass_through

        memory_warm_start = self.hg_conv.conv_memory_warm_start(
            self.memory_warm_start
        )

        gfo_objective_function = gfo_wrapper_model(self.s_space())

        # ``max_time`` is not a setup_search argument and is never set in this
        # class; presumably the caller assigns it.  Fall back to None instead
        # of raising AttributeError when it was never set.
        # NOTE(review): None is assumed to mean "no time limit" for the
        # backend -- confirm against its search() signature.
        self.opt_algo.search(
            objective_function=gfo_objective_function,
            n_iter=self.n_iter,
            max_time=getattr(self, "max_time", None),
            max_score=self.max_score,
            early_stopping=self.early_stopping,
            memory=self.memory,
            memory_warm_start=memory_warm_start,
            verbosity=self.verbosity,
        )

        self.convert_results2hyper()

        self._add_result_attributes(
            self.best_para,
            self.best_score,
            self.opt_algo.p_bar._best_since_iter,
            self.eval_times,
            self.iter_times,
            self.positions,
            self.search_data,
            self.memory_values_df,
            self.opt_algo.random_seed,
        )
144