Passed
Push — master ( 97c085...baa350 )
by Simon
04:37
created

gradient_free_optimizers.search.Search.search()   C

Complexity

Conditions 9

Size

Total Lines 84
Code Lines 64

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 64
nop 11
dl 0
loc 84
rs 5.8448
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent part of the body into its own, well-named method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

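There are several approaches to avoid long parameter lists: introduce a parameter object that groups related arguments, replace a parameter with a method call when the callee can obtain the value itself, or preserve the whole object instead of passing several of its fields separately.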

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
import pandas as pd
10
11
from multiprocessing.managers import DictProxy
12
13
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
14
from .times_tracker import TimesTracker
15
from .memory import Memory
16
from .print_info import print_info
17
from .init_positions import Initializer
18
19
20
def time_exceeded(start_time, max_time):
    """Tell whether more than ``max_time`` seconds have elapsed since ``start_time``.

    ``start_time`` is a ``time.time()`` timestamp. A falsy ``max_time``
    (``None`` or ``0``) means "no time limit"; in that case the falsy value
    itself is returned, so callers should only rely on its truthiness.
    """
    elapsed = time.time() - start_time
    if not max_time:
        return max_time
    return elapsed > max_time
23
24
25
def score_exceeded(score_best, max_score):
    """Return True once ``score_best`` has reached the target ``max_score``.

    ``max_score=None`` disables the check. The test uses ``is not None``
    instead of truthiness so that a target of ``0`` (e.g. a negated loss,
    which is common because scores are maximized) is still honoured.
    """
    if max_score is None:
        return False
    return score_best >= max_score
27
28
29
def no_change(score_new_list, early_stopping):
    """Decide whether the search should stop because scores stopped improving.

    Parameters
    ----------
    score_new_list : list of float
        Scores in evaluation order (higher is better).
    early_stopping : dict
        Must contain ``"n_iter_no_change"``; may contain ``"tol_abs"`` and/or
        ``"tol_rel"`` (percent). A missing key or ``None`` value disables the
        respective tolerance.

    Returns
    -------
    bool
        True if the best score was first reached more than
        ``n_iter_no_change`` evaluations ago, or if the improvement achieved
        within the last ``n_iter_no_change`` evaluations over the earlier best
        is below ``tol_abs`` (absolute) or ``tol_rel`` (relative, in percent).
    """
    if "n_iter_no_change" not in early_stopping:
        print(
            "Warning n_iter_no_change-parameter must be set in order for early stopping to work"
        )
        return False

    n_iter_no_change = early_stopping["n_iter_no_change"]
    n_scores = len(score_new_list)
    if n_scores <= n_iter_no_change:
        return False

    max_score = max(score_new_list)
    # np.argmax returns the FIRST occurrence of the best score.
    max_index = int(np.argmax(np.asarray(score_new_list)))

    # The best score is older than the no-change window.
    if n_scores - max_index > n_iter_no_change:
        return True

    # Best score seen before the last n_iter_no_change evaluations.
    max_first_n = max(score_new_list[: n_scores - n_iter_no_change])

    tol_abs = early_stopping.get("tol_abs")
    if tol_abs is not None and abs(max_first_n - max_score) < tol_abs:
        return True

    tol_rel = early_stopping.get("tol_rel")
    # With an earlier best of exactly 0 the relative improvement is unbounded
    # (max_score is strictly larger here), so it can never be below tol_rel;
    # skipping also avoids a ZeroDivisionError.
    if tol_rel is not None and max_first_n != 0:
        percent_imp = ((max_score - max_first_n) / abs(max_first_n)) * 100
        if percent_imp < tol_rel:
            return True

    return False
68
69
70
def set_random_seed(nth_process, random_state):
    """
    Sets the random seed separately for each parallel process
    (to avoid getting the same results in each process).

    Python's ``random`` module and NumPy's global RNG are both seeded with
    ``random_state + nth_process``. ``nth_process=None`` counts as 0; with
    ``random_state=None`` a base seed is drawn from NumPy's current RNG state.
    """
    if nth_process is None:
        nth_process = 0

    if random_state is None:
        random_state = np.random.randint(0, high=2 ** 31 - 2, dtype=np.int64)

    seed = random_state + nth_process
    # random.seed() rejects NumPy integer scalars (TypeError since Python
    # 3.11), so convert them to a plain int first.
    if isinstance(seed, np.integer):
        seed = int(seed)

    random.seed(seed)
    np.random.seed(seed)
83
84
85
class Search(TimesTracker):
    """Drive the initialization and iteration loop of an optimizer.

    The methods below use ``self.conv``, ``self.initialize``,
    ``self.results_mang``, ``self.init_pos``, ``self.evaluate``,
    ``self.iterate``, ``self.finish_initialization`` and
    ``self.score_new_list`` without defining them here; presumably they are
    provided by the optimizer class that inherits from ``Search`` (verify).
    """

    # Used when ``search(verbosity=None)``; a fresh list is built on every call
    # so no caller can mutate a shared default.
    _DEFAULT_VERBOSITY = ("progress_bar", "print_results", "print_times")

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

    @TimesTracker.eval_time
    def _score(self, pos):
        """Evaluate the wrapped objective at ``pos`` (timed by TimesTracker.eval_time)."""
        return self.score(pos)

    @TimesTracker.iter_time
    def _initialization(self, init_pos, nth_iter):
        """Evaluate one initial position and report it to the progress bar."""
        self.nth_iter = nth_iter
        self.best_score = self.p_bar.score_best

        self.init_pos(init_pos)

        score_new = self._score(init_pos)
        self.evaluate(score_new)

        self.p_bar.update(score_new, init_pos, nth_iter)

    @TimesTracker.iter_time
    def _iteration(self, nth_iter):
        """Ask the optimizer for a new position, evaluate it and report it."""
        self.nth_iter = nth_iter
        self.best_score = self.p_bar.score_best

        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, pos_new, nth_iter)

    def _init_search(self):
        """Reset the stop flag, create the progress bar and seed the RNGs."""
        self.stop = False

        # The detailed progress bar is used only when explicitly requested.
        if "progress_bar" in self.verbosity:
            bar_class = ProgressBarLVL1
        else:
            bar_class = ProgressBarLVL0
        self.p_bar = bar_class(self.nth_process, self.n_iter, self.objective_function)

        set_random_seed(self.nth_process, self.random_state)

    def _search_max_score_reached(self):
        """Return True once the best score so far is ``>= self.max_score``.

        Uses ``is not None`` rather than truthiness so that a target of 0 is
        honoured.
        """
        return self.max_score is not None and self.p_bar.score_best >= self.max_score

    def _early_stop(self):
        """Set ``self.stop`` if the time budget, score target or early-stopping rule is hit."""
        if self.stop:
            return

        if self.max_time and time_exceeded(self.start_time, self.max_time):
            self.stop = True
        elif self._search_max_score_reached():
            self.stop = True
        elif self.early_stopping and no_change(
            self.score_new_list, self.early_stopping
        ):
            self.stop = True

    def print_info(self, *args):
        # Delegates to the module-level print_info function.
        print_info(*args)

    def search(
        self,
        objective_function,
        n_iter,
        max_time=None,
        max_score=None,
        early_stopping=None,
        memory=True,
        memory_warm_start=None,
        verbosity=None,
        random_state=None,
        nth_process=None,
    ):
        """Run the optimization and store its outcome on ``self``.

        Parameters
        ----------
        objective_function : callable
            Function whose score is maximized. It is wrapped by
            ``self.results_mang`` (and by a ``Memory`` cache if enabled)
            before evaluation.
        n_iter : int
            Total number of evaluations, initialization positions included.
        max_time : float, optional
            Wall-clock budget in seconds; a falsy value means unlimited.
        max_score : float, optional
            Stop as soon as the best score is ``>= max_score``. ``None``
            disables the check; ``0`` is a valid target.
        early_stopping : dict, optional
            Passed to ``no_change``: ``"n_iter_no_change"`` plus optional
            ``"tol_abs"`` / ``"tol_rel"``.
        memory : bool or DictProxy
            ``True`` caches scores in a private ``Memory``. A ``DictProxy``
            caches into that shared dict. Any other value disables caching.
        memory_warm_start : optional
            Forwarded unchanged to ``Memory`` (format not visible here).
        verbosity : list of str, False or None
            ``"progress_bar"`` selects the detailed progress bar, and the whole
            list is passed to ``print_info``. ``False`` means no output.
            ``None`` means ``["progress_bar", "print_results", "print_times"]``.
        random_state, nth_process : int, optional
            Seed base and process index, combined by ``set_random_seed``.

        After the call, ``self.results`` (a DataFrame with additional
        ``eval_time`` and ``iter_time`` columns), ``self.best_score``,
        ``self.best_value``, ``self.best_para`` and ``self.memory_dict`` are
        set.
        """
        self.start_time = time.time()

        if verbosity is None:
            verbosity = list(self._DEFAULT_VERBOSITY)
        elif verbosity is False:
            verbosity = []

        self.objective_function = objective_function
        self.n_iter = n_iter
        self.max_time = max_time
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.verbosity = verbosity
        self.random_state = random_state
        self.nth_process = nth_process

        self._init_search()

        # get init positions
        init = Initializer(self.conv)
        self.init_positions = init.set_pos(self.initialize)

        mem = self._search_set_score(objective_function, memory, memory_warm_start)

        self._search_run_initialization(n_iter)
        self.finish_initialization()
        self._search_run_iterations(n_iter)

        self._search_collect_results(mem)
        self.p_bar.close()

        self.print_info(
            verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            self.eval_times,
            self.iter_times,
            self.n_iter,
        )

    def _search_set_score(self, objective_function, memory, memory_warm_start):
        """Install ``self.score``; return the ``Memory`` cache used, or None."""
        if isinstance(memory, DictProxy):
            # A multiprocessing-managed dict acts as the (shared) cache storage.
            mem = Memory(memory_warm_start, self.conv, dict_proxy=memory)
        elif memory is True:
            mem = Memory(memory_warm_start, self.conv)
        else:
            mem = None

        if mem is None:
            func = objective_function
        else:
            func = mem.memory(objective_function)
        self.score = self.results_mang.score(func)
        return mem

    def _search_run_initialization(self, n_iter):
        """Evaluate the initial positions, at most ``n_iter`` of them."""
        for init_pos, nth_iter in zip(self.init_positions, range(n_iter)):
            self._early_stop()
            if self.stop:
                break
            self._initialization(init_pos, nth_iter)

    def _search_run_iterations(self, n_iter):
        """Run optimizer steps for the iteration indices left after initialization."""
        for nth_iter in range(len(self.init_positions), n_iter):
            self._early_stop()
            if self.stop:
                break
            self._iteration(nth_iter)

    def _search_collect_results(self, mem):
        """Store the results table, the best score/value/parameters and the memory dict."""
        self.results = pd.DataFrame(self.results_mang.results_list)

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        # presumably one entry per evaluated position, aligned with results_list (verify)
        self.results["eval_time"] = self.eval_times
        self.results["iter_time"] = self.iter_times

        # mem is None whenever no Memory cache was created (memory False, None,
        # or anything other than True / a DictProxy), so this never touches an
        # unbound cache.
        self.memory_dict = mem.memory_dict if mem is not None else {}
236