Passed
Push — master ( b0b7f3...e3d174 )
by Simon
04:31
created

gradient_free_optimizers.search.Search.search()   B

Complexity

Conditions 6

Size

Total Lines 76
Code Lines 60

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 60
nop 12
dl 0
loc 76
rs 7.3757
c 0
b 0
f 0

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

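Commonly applied refactorings include Extract Method; if the method also has many parameters or temporary variables, Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.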

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Replace Parameter with Method, Preserve Whole Object, and Introduce Parameter Object.

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
import pandas as pd
10
11
from .init_positions import Initializer
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .times_tracker import TimesTracker
14
from .results_manager import ResultsManager
15
from .memory import Memory
16
from .converter import Converter
17
from .print_info import print_info
18
19
# Progress-bar class selected by the "progress_bar" verbosity flag.
p_bar_dict = {False: ProgressBarLVL0, True: ProgressBarLVL1}
23
24
25
def time_exceeded(start_time, max_time):
    """Report whether more than ``max_time`` seconds passed since ``start_time``.

    A falsy ``max_time`` (``None`` or ``0``) means "no time limit"; in that
    case ``max_time`` itself is returned, which callers treat as False.
    """
    elapsed = time.time() - start_time
    if not max_time:
        return max_time
    return elapsed > max_time
28
29
30
def score_exceeded(score_best, max_score):
    """Return True once ``score_best`` has reached ``max_score``.

    ``max_score=None`` disables the check. Any other value is a real
    threshold, including ``0`` and negative numbers.
    """
    # Compare against None explicitly: a truthiness test would silently
    # ignore a threshold of 0 (e.g. when maximizing a negative loss).
    return max_score is not None and score_best >= max_score
32
33
34
def set_random_seed(nth_process, random_state):
    """
    Sets the random seed separately for each thread
    (to avoid getting the same results in each thread)

    Both ``random`` and ``numpy.random`` are seeded with
    ``random_state + nth_process``. ``nth_process=None`` counts as 0.
    ``random_state=None`` draws a fresh base seed from ``numpy.random``.
    The sum is folded into ``[0, 2**32)``, the range ``np.random.seed``
    accepts.
    """
    if nth_process is None:
        nth_process = 0

    if random_state is None:
        # Explicit int64: where NumPy's default integer is 32-bit
        # (NumPy < 2 on Windows) the bound 2**32 - 2 would be rejected.
        random_state = int(
            np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)
        )

    # np.random.seed only accepts 0 <= seed < 2**32, and random.seed
    # (Python 3.11+) rejects NumPy integer types, so use a plain in-range int.
    seed = int(random_state + nth_process) % 2 ** 32

    random.seed(seed)
    np.random.seed(seed)
47
48
49
class Search(TimesTracker):
    """Run loop that drives an optimizer over ``self.search_space``.

    ``search_space``, ``init_pos``, ``iterate`` and ``evaluate`` are used but
    not defined in this class; presumably the concrete optimizer subclass
    provides them (verify). This class wires them together with the progress
    bar, score memory, seeding, early stopping and result collection.
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

    @TimesTracker.eval_time
    def _score(self, pos):
        """Score ``pos`` with ``self.score``.

        The ``TimesTracker.eval_time`` wrapper appears to collect the timings
        later summed from ``self.eval_times``.
        """
        return self.score(pos)

    def _init_memory(self, memory):
        """Build ``self.memory_dict`` from ``self.memory_warm_start``.

        If the warm start is a DataFrame containing a column for every
        search-space parameter, each row's tuple of parameter values is
        mapped to its "score". Otherwise a warning naming the missing columns
        is printed and the dict stays empty. ``memory`` is accepted for
        interface compatibility but not used here.
        """
        memory_warm_start = self.memory_warm_start

        self.memory_dict = {}
        self.memory_dict_new = {}

        if not isinstance(memory_warm_start, pd.DataFrame):
            return

        para_names = list(self.search_space.keys())
        missing = set(para_names) - set(memory_warm_start.columns)

        if missing:
            # Name every missing column, not just the first one.
            missing_str = ", ".join(
                '"{}"'.format(name) for name in sorted(missing, key=str)
            )
            print(
                "\nWarning:",
                missing_str,
                "is in search_space but not in memory dataframe",
            )
            print(
                "Optimization run will continue "
                "without memory warm start\n"
            )
            return

        values = memory_warm_start[para_names].values
        scores = memory_warm_start["score"]
        value_tuples = list(map(tuple, values))
        self.memory_dict = dict(zip(value_tuples, scores))

    @TimesTracker.iter_time
    def _initialization(self, init_pos):
        """Evaluate one initial position and report it to the progress bar."""
        self.init_pos(init_pos)

        score_new = self._score(init_pos)
        self.evaluate(score_new)

        self.p_bar.update(score_new, init_pos)

    @TimesTracker.iter_time
    def _iteration(self):
        """Ask the optimizer for a new position, evaluate and report it."""
        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, pos_new)

    def _init_search(self):
        """Set up memory, progress bar and RNG seed; return initial positions."""
        self._init_memory(self.memory)
        self.p_bar = p_bar_dict[self.progress_bar](
            self.nth_process, self.n_iter, self.objective_function
        )
        set_random_seed(self.nth_process, self.random_state)

        if self.warm_start is not None:
            # Build a new dict instead of assigning into the existing one:
            # self.initialize may be the shared default argument of search()
            # or a caller-owned dict, and mutating it would leak this warm
            # start into later runs.
            self.initialize = dict(self.initialize, warm_start=self.warm_start)

        init = Initializer(self.conv)
        return init.set_pos(self.initialize)

    def _early_stop(self):
        """Return True if the time budget or the target score is reached."""
        return bool(
            time_exceeded(self.start_time, self.max_time)
            or score_exceeded(self.p_bar.score_best, self.max_score)
        )

    def _init_verb_dict(self, verb_dict):
        """Normalize ``verbosity`` into a new dict with all three flags.

        ``None``/``False`` turns every flag off. ``True`` turns every flag on.
        For a dict, missing flags default to True. The caller's dict is not
        modified.
        """
        if verb_dict in [None, False]:
            return {
                "progress_bar": False,
                "print_results": False,
                "print_times": False,
            }

        verb_default = {
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        }

        if verb_dict is True:
            return verb_default

        return {**verb_default, **verb_dict}

    def _make_score_function(self, results):
        """Return the score callable, wrapped by ``Memory`` if memory is on."""
        if self.memory is True:
            mem = Memory(self.memory_warm_start, self.conv)
            return mem.memory(results.score)
        return results.score

    def _run_search_loops(self, init_positions):
        """Evaluate initial positions, then iterate up to ``self.n_iter``.

        Both phases stop as soon as ``_early_stop`` reports True.
        """
        # At most n_iter initial positions are evaluated.
        for init_pos, _ in zip(init_positions, range(self.n_iter)):
            if self._early_stop():
                break
            self._initialization(init_pos)

        for _ in range(len(init_positions), self.n_iter):
            if self._early_stop():
                break
            self._iteration()

    def _finish_search(self, results, verbosity):
        """Store the results and best values, close the bar, print a summary."""
        self.results = pd.DataFrame(results.results_list)

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        eval_time = np.array(self.eval_times).sum()
        iter_time = np.array(self.iter_times).sum()

        self.p_bar.close()

        print_info(
            verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            eval_time,
            iter_time,
            self.n_iter,
        )

    def search(
        self,
        objective_function,
        n_iter,
        initialize={"grid": 8, "random": 4, "vertices": 8},
        warm_start=None,
        max_time=None,
        max_score=None,
        memory=True,
        memory_warm_start=None,
        verbosity={
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        },
        random_state=None,
        nth_process=None,
    ):
        """Run the optimization.

        The dict defaults for ``initialize`` and ``verbosity`` are shared
        between calls; this method never mutates them (nor caller dicts).

        Parameters
        ----------
        objective_function : callable
            Function to optimize. It is handed to ``ResultsManager``, whose
            ``score`` becomes the evaluation function, as well as to the
            progress bar and ``print_info``.
        n_iter : int
            Total number of evaluations, initial positions included.
        initialize : dict
            Initialization strategy passed to ``Initializer.set_pos``.
        warm_start : optional
            If not None, added to ``initialize`` under "warm_start".
        max_time : float or None
            Stop early after this many seconds; a falsy value disables it.
        max_score : float or None
            Stop early once the best score reaches this value.
        memory : bool
            If exactly True, the score function is wrapped by ``Memory``.
        memory_warm_start : pandas.DataFrame or None
            Previously evaluated parameters with a "score" column.
        verbosity : dict, bool or None
            Flags "progress_bar", "print_results", "print_times".
            None/False disables all; missing keys default to True.
        random_state : int or None
            Base seed; ``nth_process`` is added to it.
        nth_process : int or None
            Process index, used for seeding and the progress bar.

        Results are stored on ``self.results`` (DataFrame),
        ``self.best_score``, ``self.best_value`` and ``self.best_para``.
        """
        self.start_time = time.time()

        verbosity = self._init_verb_dict(verbosity)

        self.objective_function = objective_function
        self.n_iter = n_iter
        self.initialize = initialize
        self.warm_start = warm_start
        self.max_time = max_time
        self.max_score = max_score
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.progress_bar = verbosity["progress_bar"]
        self.random_state = random_state
        self.nth_process = nth_process

        # self.conv must exist before _init_search builds the Initializer.
        self.conv = Converter(self.search_space)
        results = ResultsManager(objective_function, self.conv)
        init_positions = self._init_search()

        self.score = self._make_score_function(results)

        self._run_search_loops(init_positions)
        self._finish_search(results, verbosity)
231
232