Passed
Push — master ( bac9d4...455948 )
by Simon
04:25
created

gradient_free_optimizers.search.Search.search()   B

Complexity

Conditions 6

Size

Total Lines 75
Code Lines 58

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 58
nop 12
dl 0
loc 75
rs 7.4412
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

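Commonly applied refactorings include Extract Method and, when many parameters or temporary variables are involved, Replace Temp with Query, Introduce Parameter Object, and Preserve Whole Object.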

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Introduce Parameter Object, Preserve Whole Object, and Replace Parameter with Method Call.

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
import pandas as pd
10
11
from .init_positions import Initializer
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .times_tracker import TimesTracker
14
from .results_manager import ResultsManager
15
from .memory import Memory
16
from .print_info import print_info
17
18
# Maps the boolean "progress_bar" verbosity flag to the progress-bar class
# that is instantiated for a search run.
p_bar_dict = {False: ProgressBarLVL0, True: ProgressBarLVL1}
22
23
24
def time_exceeded(start_time, max_time):
    """Return True when the time budget of a run has been used up.

    Parameters
    ----------
    start_time : float
        Timestamp (as returned by ``time.time()``) at which the run started.
    max_time : float or None
        Maximum run time in seconds. A falsy value (``None`` or ``0``) means
        "no time limit".

    Returns
    -------
    bool
        ``True`` if a limit is set and the elapsed time is greater than it,
        ``False`` otherwise.
    """
    if not max_time:
        # No limit configured: always report "not exceeded" as a real bool
        # instead of leaking None/0 to the caller.
        return False

    run_time = time.time() - start_time
    return bool(run_time > max_time)
27
28
29
def score_exceeded(score_best, max_score):
    """Return True when the best score has reached the target score.

    Parameters
    ----------
    score_best : float
        Best score found so far.
    max_score : float or None
        Target score. ``None`` (or ``False``) means "no target". Note that
        ``0`` is a valid target and is NOT treated as "no target".

    Returns
    -------
    bool
        ``True`` if a target is set and ``score_best >= max_score``.
    """
    # Compare against None/False explicitly: a plain truthiness test would
    # silently ignore a target score of 0.
    if max_score is None or max_score is False:
        return False

    return bool(score_best >= max_score)
31
32
33
def set_random_seed(nth_process, random_state):
    """
    Sets the random seed separately for each thread
    (to avoid getting the same results in each thread)

    Seeds both Python's ``random`` module and numpy's global random state
    with ``random_state + nth_process``.

    Parameters
    ----------
    nth_process : int or None
        Index of the process; ``None`` is treated as ``0``.
    random_state : int or None
        Base seed. When ``None`` a base seed is drawn from numpy's global
        random state.
    """
    if nth_process is None:
        nth_process = 0

    if random_state is None:
        # Request a 64-bit draw explicitly: the default integer dtype can be
        # 32 bit on some platforms, where this upper bound is out of range.
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # Convert to plain Python ints (random.seed may reject numpy integers)
    # and wrap into the 0 .. 2**32 - 1 range accepted by np.random.seed.
    seed = (int(random_state) + int(nth_process)) % (2 ** 32)

    random.seed(seed)
    np.random.seed(seed)
46
47
48
class Search(TimesTracker):
    """Run loop of an optimization: initialization, iterations, early stop.

    The class uses several members it does not define itself: ``self.conv``,
    ``self.init_pos``, ``self.iterate`` and ``self.evaluate`` (presumably
    supplied by the concrete optimizer that inherits from this class), as
    well as ``self.eval_times`` / ``self.iter_times`` (presumably maintained
    by ``TimesTracker`` -- verify against that class).
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

    @TimesTracker.eval_time
    def _score(self, pos):
        """Evaluate ``pos`` with the score function selected in ``search``."""
        return self.score(pos)

    def _init_memory(self, memory):
        """Build ``self.memory_dict`` from the memory warm-start dataframe.

        ``self.memory_dict`` and ``self.memory_dict_new`` are always reset to
        empty dicts. If ``self.memory_warm_start`` is a dataframe containing
        a column for every search-space parameter, ``self.memory_dict`` maps
        each row's parameter values (as a tuple) to the row's ``"score"``.
        Otherwise a warning naming the missing columns is printed and the
        memory stays empty.

        Parameters
        ----------
        memory
            Not used by this method; kept for interface compatibility.
        """
        memory_warm_start = self.memory_warm_start

        self.memory_dict = {}
        self.memory_dict_new = {}

        if not isinstance(memory_warm_start, pd.DataFrame):
            return

        para_names = list(self.conv.search_space.keys())
        parameter = set(para_names)
        memory_para = set(memory_warm_start.columns)

        if parameter <= memory_para:
            values = memory_warm_start[para_names].values
            scores = memory_warm_start["score"]

            self.memory_dict = dict(zip(map(tuple, values), scores))
        else:
            # Report every missing column, not just an arbitrary one.
            missing = sorted(parameter - memory_para, key=str)
            names = ", ".join('"{}"'.format(name) for name in missing)
            verb = "is" if len(missing) == 1 else "are"

            print(
                "\nWarning:",
                names,
                verb + " in search_space but not in memory dataframe",
            )
            print(
                "Optimization run will continue "
                "without memory warm start\n"
            )

    @TimesTracker.iter_time
    def _initialization(self, init_pos):
        """Register one initial position, score it and report the result."""
        self.init_pos(init_pos)

        score_new = self._score(init_pos)
        self.evaluate(score_new)

        self.p_bar.update(score_new, init_pos)

    @TimesTracker.iter_time
    def _iteration(self):
        """Ask for a new position, score it and report the result."""
        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, pos_new)

    def _init_search(self):
        """Prepare memory, progress bar and random seed for a run.

        Returns
        -------
        The initial positions produced by ``Initializer.set_pos`` for
        ``self.initialize``.
        """
        self._init_memory(self.memory)
        self.p_bar = p_bar_dict[self.progress_bar](
            self.nth_process, self.n_iter, self.objective_function
        )
        set_random_seed(self.nth_process, self.random_state)

        if self.warm_start is not None:
            # Build a new dict instead of writing into the existing one, so
            # that neither the caller's dict nor the default argument of
            # ``search`` keeps the warm start for later runs.
            self.initialize = {
                **self.initialize,
                "warm_start": self.warm_start,
            }

        # get init positions
        init = Initializer(self.conv)
        init_positions = init.set_pos(self.initialize)

        return init_positions

    def _early_stop(self):
        """Return True if the time limit or the target score is reached."""
        if time_exceeded(self.start_time, self.max_time):
            return True
        if score_exceeded(self.p_bar.score_best, self.max_score):
            return True
        return False

    def _init_verb_dict(self, verb_dict):
        """Return a complete verbosity dict without modifying the input.

        ``None`` or ``False`` switch every output off. Otherwise keys missing
        from ``verb_dict`` are filled in with ``True``.
        """
        if verb_dict in [None, False]:
            return {
                "progress_bar": False,
                "print_results": False,
                "print_times": False,
            }

        verb_default = {
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        }

        # Merge into a fresh dict: user-supplied entries win over defaults
        # and the caller's dict is left untouched.
        return {**verb_default, **verb_dict}

    def print_info(self, *args):
        """Forward all arguments to the module-level ``print_info``."""
        print_info(*args)

    def search(
        self,
        objective_function,
        n_iter,
        initialize={"grid": 8, "random": 4, "vertices": 8},
        warm_start=None,
        max_time=None,
        max_score=None,
        memory=True,
        memory_warm_start=None,
        verbosity={
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        },
        random_state=None,
        nth_process=None,
    ):
        """Run the optimization and store its results on the instance.

        Parameters
        ----------
        objective_function : callable
            Function to optimize; handed to ``ResultsManager``.
        n_iter : int
            Total number of evaluations, counting the initial positions.
        initialize : dict
            Initialization settings handed to ``Initializer.set_pos``.
        warm_start : optional
            If not ``None`` it is added to the initialization settings under
            the key ``"warm_start"``.
        max_time : float or None
            Time limit in seconds, checked before every evaluation.
        max_score : float or None
            Target score, checked before every evaluation.
        memory : bool
            If exactly ``True``, scores are obtained through ``Memory``.
        memory_warm_start : pandas.DataFrame or None
            Earlier results handed to ``Memory``.
        verbosity : dict, None or False
            Output switches ``"progress_bar"``, ``"print_results"`` and
            ``"print_times"``; ``None`` or ``False`` disable all output.
        random_state : int or None
            Base random seed.
        nth_process : int or None
            Process index, added to the random seed.

        After the run ``self.results``, ``self.best_score``,
        ``self.best_value`` and ``self.best_para`` are set.
        """
        self.start_time = time.time()

        verbosity = self._init_verb_dict(verbosity)

        self.objective_function = objective_function
        self.n_iter = n_iter
        # Work on a copy: the default argument is a shared dict and must
        # never be modified by a run.
        self.initialize = dict(initialize)
        self.warm_start = warm_start
        self.max_time = max_time
        self.max_score = max_score
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.progress_bar = verbosity["progress_bar"]
        self.random_state = random_state
        self.nth_process = nth_process

        results = ResultsManager(objective_function, self.conv)
        init_positions = self._init_search()

        if memory is True:
            mem = Memory(memory_warm_start, self.conv)
            self.score = mem.memory(results.score)
        else:
            self.score = results.score

        # loop to initialize N positions (at most n_iter of them)
        for init_pos, _ in zip(init_positions, range(n_iter)):
            if self._early_stop():
                break
            self._initialization(init_pos)

        # loop to do the iterations
        for _ in range(len(init_positions), n_iter):
            if self._early_stop():
                break
            self._iteration()

        self.results = pd.DataFrame(results.results_list)

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        eval_time = np.array(self.eval_times).sum()
        iter_time = np.array(self.iter_times).sum()

        self.p_bar.close()

        self.print_info(
            verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            eval_time,
            iter_time,
            self.n_iter,
        )
234
235