Passed
Pull Request — master (#101)
by Simon
01:35
created

hyperactive.optimizers.search.Search.setup()   A

Complexity

Conditions 1

Size

Total Lines 26
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 24
nop 12
dl 0
loc 26
rs 9.304
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more or different data.

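There are several approaches to avoid long parameter lists, for example introducing a parameter object, passing a whole object instead of several of its fields, or replacing a parameter with a method call.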

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
from .objective_function import ObjectiveFunction
6
from .hyper_gradient_conv import HyperGradientConv
7
from .optimizer_attributes import OptimizerAttributes
8
from .constraint import Constraint
9
10
11
class Search(OptimizerAttributes):
    """Run a single optimization search with a wrapped optimizer class.

    The instance is configured in stages: ``__init__`` stores the optimizer
    class and its parameters, ``setup`` stores the search configuration,
    ``pass_args`` stores the run-time arguments, and ``_search`` builds the
    optimizer, drives the iteration loop and collects the results.
    """

    max_time: float
    nth_process: int

    def __init__(self, optimizer_class, opt_params):
        """Remember which optimizer to build and the parameters to build it with.

        Parameters
        ----------
        optimizer_class
            Callable used in ``_setup_process`` to create the optimizer.
        opt_params
            Mapping that is unpacked as keyword arguments into
            ``optimizer_class``. Note that ``_setup_process`` replaces its
            ``"warm_start_smbo"`` entry in place when that key is present.
        """
        super().__init__()
        self.optimizer_class = optimizer_class
        self.opt_params = opt_params

    def setup(
        self,
        experiment,
        s_space,
        n_iter,
        initialize,
        constraints,
        pass_through,
        max_score,
        early_stopping,
        random_state,
        memory,
        memory_warm_start,
    ):
        """Store the search configuration on the instance.

        Every argument is kept unchanged under an attribute of the same
        name; nothing is validated or converted here. The values are read
        later by ``_setup_process`` and ``_search``.
        """
        # What is optimized, where, and for how many iterations.
        self.experiment = experiment
        self.s_space = s_space
        self.n_iter = n_iter

        # Remaining options, forwarded later to the optimizer or to the
        # objective-function wrapper.
        self.initialize = initialize
        self.constraints = constraints
        self.pass_through = pass_through
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.random_state = random_state
        self.memory = memory
        self.memory_warm_start = memory_warm_start

    def pass_args(self, max_time, nth_process, verbosity):
        """Store the run-time arguments of this search.

        Parameters
        ----------
        max_time
            Stored as ``self.max_time`` and handed to ``init_search``.
        nth_process
            Stored as ``self.nth_process``; passed to the optimizer and shown
            in the progress-bar description.
        verbosity
            Container that is tested for the entry ``"progress_bar"``. Only
            that entry is kept: ``self.verbosity`` becomes
            ``["progress_bar"]`` if it is present and ``[]`` otherwise.
        """
        self.max_time = max_time
        self.nth_process = nth_process
        self.verbosity = ["progress_bar"] if "progress_bar" in verbosity else []

    def convert_results2hyper(self):
        """Copy the optimizer's results onto this instance in converted form.

        Sets ``eval_times``, ``iter_times``, ``best_para``, ``best_score``,
        ``positions``, ``search_data`` and ``memory_values_df``. Requires
        ``self.gfo_optimizer`` and ``self.hg_conv`` to exist, i.e.
        ``_setup_process`` must have run before.
        """
        optimizer = self.gfo_optimizer
        converter = self.hg_conv

        # The optimizer keeps per-step timings; only the totals are stored.
        self.eval_times = sum(optimizer.eval_times)
        self.iter_times = sum(optimizer.iter_times)

        if optimizer.best_para is None:
            self.best_para = None
        else:
            # Round trip through the converter: para -> value -> position
            # (via ``position2value``) -> para.
            as_value = converter.para2value(optimizer.best_para)
            as_position = converter.position2value(as_value)
            self.best_para = converter.value2para(as_position)

        self.best_score = optimizer.best_score
        self.positions = optimizer.search_data
        self.search_data = converter.positions2results(self.positions)

        # Keep the first occurrence of every point of the search space and
        # reduce the frame to the dimension columns plus the score.
        dim_keys = self.s_space.dim_keys
        unique_rows = optimizer.search_data.drop_duplicates(
            subset=dim_keys,
            keep="first",
        )
        score_columns = self.s_space.dim_keys + ["score"]
        self.memory_values_df = unique_rows[score_columns].reset_index(drop=True)

    def _setup_process(self):
        """Create the converter and the optimizer instance for this search.

        Sets ``self.hg_conv``, ``self.gfo_optimizer`` and ``self.conv``.
        As a side effect the ``"warm_start_smbo"`` entry of
        ``self.opt_params`` (if present) is replaced by its converted form.
        """
        self.hg_conv = HyperGradientConv(self.s_space)
        converter = self.hg_conv

        converted_initialize = converter.conv_initialize(self.initialize)
        position_space = self.s_space.positions

        # conv warm start for smbo from values into positions
        smbo_key = "warm_start_smbo"
        if smbo_key in self.opt_params:
            self.opt_params[smbo_key] = converter.conv_memory_warm_start(
                self.opt_params[smbo_key]
            )

        # Wrap each constraint together with the search space.
        wrapped_constraints = []
        for constraint in self.constraints:
            wrapped_constraints.append(Constraint(constraint, self.s_space))

        self.gfo_optimizer = self.optimizer_class(
            search_space=position_space,
            initialize=converted_initialize,
            constraints=wrapped_constraints,
            random_state=self.random_state,
            nth_process=self.nth_process,
            **self.opt_params,
        )

        self.conv = self.gfo_optimizer.conv

    def _search(self, p_bar):
        """Run the complete search and store its results.

        Parameters
        ----------
        p_bar
            Progress-bar object, or a falsy value to run without one. When
            truthy, its ``set_description``, ``set_postfix``, ``update`` and
            ``refresh`` methods are called during the loop.
        """
        self._setup_process()
        optimizer = self.gfo_optimizer

        objective_wrapper = ObjectiveFunction(experiment=self.experiment)
        objective_wrapper.pass_through = self.pass_through

        converted_warm_start = self.hg_conv.conv_memory_warm_start(
            self.memory_warm_start
        )

        objective_function = objective_wrapper(self.s_space())

        # Positional call: the argument order is part of the optimizer's
        # interface and is kept exactly as is.
        optimizer.init_search(
            objective_function,
            self.n_iter,
            self.max_time,
            self.max_score,
            self.early_stopping,
            self.memory,
            converted_warm_start,
            False,
        )

        for step in range(self.n_iter):
            if p_bar:
                description = "".join(
                    (
                        "[",
                        str(self.nth_process),
                        "] ",
                        str(self.experiment.__class__.__name__),
                        " (",
                        self.optimizer_class.name,
                        ")",
                    )
                )
                p_bar.set_description(description)

            optimizer.search_step(step)

            # The stop check comes before the progress update, so the bar is
            # not advanced for the iteration that triggers the stop.
            if optimizer.stop.check():
                break

            if p_bar:
                p_bar.set_postfix(
                    best_score=str(optimizer.score_best),
                    best_pos=str(optimizer.pos_best),
                    best_iter=str(optimizer.p_bar._best_since_iter),
                )
                p_bar.update(1)
                p_bar.refresh()

        optimizer.finish_search()

        self.convert_results2hyper()

        # ``_add_result_attributes`` is not defined in this class; it is
        # presumably inherited from ``OptimizerAttributes``.
        self._add_result_attributes(
            self.best_para,
            self.best_score,
            optimizer.p_bar._best_since_iter,
            self.eval_times,
            self.iter_times,
            self.search_data,
            optimizer.random_seed,
        )
167