Passed
Push — master ( 0c1576...8c2877 )
by Simon
04:40 queued 12s
created

HyperOptimizer.setup_search()   A

Complexity

Conditions 2

Size

Total Lines 35
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 31
nop 14
dl 0
loc 35
rs 9.1359
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

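Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you later need more or different data.

There are several approaches to avoid long parameter lists, for example: grouping related parameters into a single object (Introduce Parameter Object), passing a whole object instead of several of its fields (Preserve Whole Object), or letting the method obtain a value itself instead of receiving it as an argument (Replace Parameter with Query).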

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import numpy as np
6
import pandas as pd
7
8
9
from .objective_function import ObjectiveFunction
10
from .hyper_gradient_conv import HyperGradientConv
11
from .base_optimizer import BaseOptimizer
12
13
14
class HyperOptimizer(BaseOptimizer):
    """Run a wrapped optimization algorithm on a Hyperactive search space.

    The concrete algorithm is instantiated from ``self._OptimizerClass``,
    which is not defined in this class (presumably provided by subclasses —
    confirm). Results are converted back into Hyperactive terms through a
    ``HyperGradientConv`` built from the search space and then handed to
    ``self._add_result_attributes`` (inherited, presumably from
    ``BaseOptimizer``).
    """

    def __init__(self, **opt_params):
        """Store keyword arguments forwarded verbatim to ``_OptimizerClass``.

        Args:
            **opt_params: Algorithm-specific parameters. The key
                ``"warm_start_smbo"`` is special-cased: its value is converted
                with ``HyperGradientConv.conv_memory_warm_start`` before being
                passed on (see ``_setup_process``).
        """
        super().__init__()
        self.opt_params = opt_params

    def setup_search(
        self,
        objective_function,
        s_space,
        n_iter,
        initialize,
        pass_through,
        callbacks,
        catch,
        max_score,
        early_stopping,
        random_state,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Record the configuration used by a later call to ``search``.

        Nothing is executed here; every argument is stored on ``self``.

        Args:
            objective_function: User objective, wrapped in ``ObjectiveFunction``.
            s_space: Search space. ``search`` reads its ``positions`` and
                ``dim_keys`` attributes and also calls it.
            n_iter: Iteration count forwarded to the algorithm's ``search``.
            initialize: Initialization spec, converted by
                ``HyperGradientConv.conv_initialize``.
            pass_through: Assigned to the ``ObjectiveFunction`` wrapper.
            callbacks: Forwarded to ``ObjectiveFunction``.
            catch: Forwarded to ``ObjectiveFunction``.
            max_score: Forwarded to the algorithm's ``search``.
            early_stopping: Forwarded to the algorithm's ``search``.
            random_state: Forwarded to the ``_OptimizerClass`` constructor.
            memory: Forwarded to the algorithm's ``search``.
            memory_warm_start: Converted by
                ``HyperGradientConv.conv_memory_warm_start``, then forwarded.
            verbosity: Collection of output flags. Only ``"progress_bar"``
                is kept; a falsy value (``None``, ``False``, empty) disables
                all output.
        """
        self.objective_function = objective_function
        self.s_space = s_space
        self.n_iter = n_iter

        self.initialize = initialize
        self.pass_through = pass_through
        self.callbacks = callbacks
        self.catch = catch
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.random_state = random_state
        self.memory = memory
        self.memory_warm_start = memory_warm_start

        # The wrapped algorithm only understands the progress-bar flag; every
        # other verbosity entry is handled elsewhere (if at all). Guard the
        # membership test so None/False mean "no output" instead of raising.
        if verbosity and "progress_bar" in verbosity:
            self.verbosity = ["progress_bar"]
        else:
            self.verbosity = []

    def convert_results2hyper(self):
        """Copy the wrapped algorithm's results onto ``self`` in Hyperactive form.

        Sets ``eval_times``/``iter_times`` (summed), ``best_para``,
        ``best_score``, ``positions``, ``search_data`` and
        ``memory_values_df``. The latter contains the first occurrence of each
        distinct parameter combination with its score.
        """
        self.eval_times = np.array(self.opt_algo.eval_times).sum()
        self.iter_times = np.array(self.opt_algo.iter_times).sum()

        best_para_raw = self.opt_algo.best_para
        if best_para_raw is None:
            self.best_para = None
        else:
            # Round-trip through the converter (para -> value -> position2value
            # -> para); the exact semantics of each step live in
            # HyperGradientConv.
            value = self.hg_conv.para2value(best_para_raw)
            position = self.hg_conv.position2value(value)
            self.best_para = self.hg_conv.value2para(position)

        self.best_score = self.opt_algo.best_score
        self.positions = self.opt_algo.search_data

        self.search_data = self.hg_conv.positions2results(self.positions)

        dim_keys = self.s_space.dim_keys
        results_dd = self.positions.drop_duplicates(subset=dim_keys, keep="first")
        self.memory_values_df = results_dd[dim_keys + ["score"]].reset_index(
            drop=True
        )

    def _setup_process(self, nth_process):
        """Build the converter and instantiate the wrapped algorithm.

        Args:
            nth_process: Index of the process running this search; stored on
                ``self`` and forwarded to ``_OptimizerClass``.
        """
        self.nth_process = nth_process

        self.hg_conv = HyperGradientConv(self.s_space)

        initialize = self.hg_conv.conv_initialize(self.initialize)
        search_space_positions = self.s_space.positions

        # Convert the SMBO warm start from values into positions on a shallow
        # copy. Writing the result back into self.opt_params would make a
        # second call on this instance convert already-converted data.
        opt_params = dict(self.opt_params)
        if "warm_start_smbo" in opt_params:
            opt_params["warm_start_smbo"] = self.hg_conv.conv_memory_warm_start(
                opt_params["warm_start_smbo"]
            )

        self.opt_algo = self._OptimizerClass(
            search_space=search_space_positions,
            initialize=initialize,
            random_state=self.random_state,
            nth_process=nth_process,
            **opt_params
        )

        self.conv = self.opt_algo.conv

    def search(self, nth_process):
        """Run the configured search and store its results on ``self``.

        ``setup_search`` must have been called first.

        Args:
            nth_process: Index of the process running this search.
        """
        self._setup_process(nth_process)

        gfo_wrapper_model = ObjectiveFunction(
            objective_function=self.objective_function,
            optimizer=self.opt_algo,
            callbacks=self.callbacks,
            catch=self.catch,
            nth_process=self.nth_process,
        )
        gfo_wrapper_model.pass_through = self.pass_through

        memory_warm_start = self.hg_conv.conv_memory_warm_start(self.memory_warm_start)

        gfo_objective_function = gfo_wrapper_model(self.s_space())

        # max_time is not set by setup_search; callers assign it on the
        # instance. Fall back to None (the wrapped search's default,
        # presumably "no time limit") instead of raising AttributeError.
        max_time = getattr(self, "max_time", None)

        self.opt_algo.search(
            objective_function=gfo_objective_function,
            n_iter=self.n_iter,
            max_time=max_time,
            max_score=self.max_score,
            early_stopping=self.early_stopping,
            memory=self.memory,
            memory_warm_start=memory_warm_start,
            verbosity=self.verbosity,
        )

        self.convert_results2hyper()

        self._add_result_attributes(
            self.best_para,
            self.best_score,
            self.opt_algo.p_bar._best_since_iter,
            self.eval_times,
            self.iter_times,
            self.positions,
            self.search_data,
            self.memory_values_df,
            self.opt_algo.random_seed,
        )
144