Passed
Push — master ( 23d304...5d9e6c )
by Simon
01:08
created

gradient_free_optimizers.search.Search.search()   B

Complexity

Conditions 4

Size

Total Lines 53
Code Lines 38

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 38
nop 9
dl 0
loc 53
rs 8.968
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent part of the body into a new, well-named helper method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

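There are several approaches to avoid long parameter lists, for example Introduce Parameter Object (group related parameters into a single object) and Preserve Whole Object (pass the object itself instead of several values taken from it).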

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
from tqdm import tqdm
10
11
from .init_positions import init_grid_search, init_random_search
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
14
15
# Lookup table from the ``verbosity`` argument of ``Search.search`` to the
# progress-bar class it instantiates: index 0 -> ProgressBarLVL0,
# index 1 -> ProgressBarLVL1.
p_bar_dict = dict(enumerate((ProgressBarLVL0, ProgressBarLVL1)))
19
20
21
def time_exceeded(start_time, max_time):
    """Return True if more than ``max_time`` seconds passed since ``start_time``.

    ``start_time`` is a ``time.time()`` timestamp. A falsy ``max_time``
    (``None`` or ``0``) means "no time limit", so the result is ``False``.
    The result is always a real ``bool`` (the old version could return
    ``None`` or ``0``).
    """
    if not max_time:
        return False
    return time.time() - start_time > max_time
24
25
26
def set_random_seed(nth_process, random_state):
    """Seed ``random`` and ``numpy.random`` with a per-process offset.

    Each process/thread is seeded with ``random_state + nth_process`` so that
    parallel runs do not produce identical results. If ``random_state`` is
    None, a base seed is drawn from numpy's global generator first.

    The combined seed is converted to a plain Python ``int`` and wrapped into
    numpy's valid seed range [0, 2**32 - 1]. Seeds already inside that range
    are used unchanged.
    """
    if random_state is None:
        # Explicit 64-bit dtype: the default integer dtype is 32-bit on some
        # platforms (e.g. Windows with numpy < 2), where high=2**32 - 2 would
        # be out of bounds.
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # int(): random.seed() rejects numpy integer types on Python >= 3.11.
    # % 2**32: np.random.seed() rejects seeds outside [0, 2**32 - 1].
    seed = int(random_state + nth_process) % 2 ** 32

    random.seed(seed)
    np.random.seed(seed)
33
34
35
class Search:
    """Shared search loop and position/value conversion helpers.

    A *position* is a sequence of indices, one per dimension of
    ``self.search_space``. A *value* is the corresponding sequence of
    search-space entries.

    NOTE(review): this class uses attributes and methods it does not define
    (``search_space``, ``space_dim``, ``init_pos``, ``iterate``, ``evaluate``,
    ``eval_times``, ``iter_times``). Presumably they are provided by the
    concrete optimizer classes; confirm.
    """

    def _values2positions(self, values):
        """Convert search-space values to index positions.

        Parameters
        ----------
        values : sequence of sequences
            One entry per point. Each entry holds one value per dimension of
            ``self.search_space``.

        Returns
        -------
        list of numpy.ndarray
            One position (array of per-dimension indices) per entry of
            ``values``, in the same order.

        Raises
        ------
        ValueError
            If an entry has the wrong number of dimensions, or a value does
            not occur in its dimension of the search space.
        """
        n_dims = len(self.search_space)
        positions = []

        for value in values:
            if len(value) != n_dims:
                raise ValueError(
                    "expected {} values per point, got {}".format(n_dims, len(value))
                )

            position = []
            for n, (space_dim, dim_value) in enumerate(zip(self.search_space, value)):
                matches = np.where(np.asarray(space_dim) == dim_value)[0]
                if matches.size == 0:
                    raise ValueError(
                        "value {!r} is not in dimension {} of the search space".format(
                            dim_value, n
                        )
                    )
                # First matching index, in case a dimension holds duplicates.
                position.append(matches[0])

            positions.append(np.array(position))

        return positions

    def _positions2values(self, positions):
        """Convert a sequence of index positions to a list of value arrays.

        Each row of ``positions`` is one position. The result holds one value
        array per position.
        """
        positions_np = np.array(positions)
        pos_converted = [
            np.take(space_dim, positions_np[:, n], axis=0)
            for n, space_dim in enumerate(self.search_space)
        ]
        return list(np.array(pos_converted).T)

    def _init_values(self, init_values):
        """Build the list of starting positions from an ``initialize`` dict.

        Recognised keys are ``"random"`` and ``"grid"`` (forwarded with
        ``self.space_dim`` to the init_positions helpers) and
        ``"warm_start"`` (a list of value vectors, converted to positions).
        Positions are returned in the order random, grid, warm_start.
        """
        init_positions_list = []

        if "random" in init_values:
            init_positions_list.append(
                init_random_search(self.space_dim, init_values["random"])
            )
        if "grid" in init_values:
            init_positions_list.append(
                init_grid_search(self.space_dim, init_values["grid"])
            )
        if "warm_start" in init_values:
            init_positions_list.append(
                self._values2positions(init_values["warm_start"])
            )

        return [item for sublist in init_positions_list for item in sublist]

    def _position2value(self, position):
        """Return the search-space value for each index in ``position``."""
        return [space_dim[position[n]] for n, space_dim in enumerate(self.search_space)]

    def _score_mem(self, pos):
        """Score ``pos`` with the objective function, caching by ``tuple(pos)``."""
        pos_tuple = tuple(pos)

        if pos_tuple in self.memory_dict:
            return self.memory_dict[pos_tuple]

        score = self.objective_function(pos)
        self.memory_dict[pos_tuple] = score
        return score

    def _init_memory(self, memory):
        """Select the scoring function according to ``memory``.

        ``memory`` may be:

        * ``False``: call the objective function directly.
        * ``True``: cache scores in a fresh ``memory_dict``.
        * a dict with ``"values"`` and ``"scores"``: cache scores, pre-filled
          with those value/score pairs.

        Raises
        ------
        ValueError
            For any other ``memory`` argument. Previously ``_score`` was left
            unset (or stale from an earlier search) in that case.
        """
        # ``==`` rather than ``is`` keeps accepting 0/1 as before.
        if memory == False:  # noqa: E712
            self._score = self.objective_function
        elif memory == True:  # noqa: E712
            self._score = self._score_mem
            self.memory_dict = {}
        elif isinstance(memory, dict):
            self._score = self._score_mem
            value_tuple_list = list(map(tuple, memory["values"]))
            self.memory_dict = dict(zip(value_tuple_list, memory["scores"]))
        else:
            raise ValueError(
                "memory must be True, False or a dict with 'values' and "
                "'scores', got {!r}".format(memory)
            )

    def _score_and_record(self, value):
        """Score ``value``, advance the progress bar and record the eval time."""
        start_time_eval = time.time()
        score = self._score(value)
        self.p_bar.update(1, score)
        self.eval_times.append(time.time() - start_time_eval)
        return score

    def _search_init_step(self, init_position):
        """Run one initialization step starting from ``init_position``."""
        start_time_iter = time.time()
        self.init_pos(init_position)

        # Score the *value* of the position, exactly like the main loop does.
        # Previously the raw indices were passed to the objective function.
        score_new = self._score_and_record(self._position2value(init_position))

        self.evaluate(score_new)
        self.iter_times.append(time.time() - start_time_iter)

    def _search_iter_step(self):
        """Run one optimizer iteration: propose, score and report a position."""
        start_time_iter = time.time()
        pos_new = self.iterate()

        score_new = self._score_and_record(self._position2value(pos_new))

        self.evaluate(score_new)
        self.iter_times.append(time.time() - start_time_iter)

    def search(
        self,
        objective_function,
        n_iter,
        initialize=None,
        max_time=None,
        memory=True,
        verbosity=1,
        random_state=None,
        nth_process=0,
    ):
        """Run the optimization.

        Parameters
        ----------
        objective_function : callable
            Called with a list of search-space values (one per dimension);
            its return value is treated as the score.
        n_iter : int
            Number of iterations, including the initialization steps.
        initialize : dict, optional
            Starting positions; see ``_init_values``. ``None`` means
            ``{"grid": 7, "random": 3}``.
        max_time : float, optional
            Time limit in seconds for the main loop. Falsy means no limit.
        memory : bool or dict
            Score caching mode; see ``_init_memory``.
        verbosity : int
            Key into ``p_bar_dict`` selecting the progress bar (0 or 1).
        random_state : int, optional
            Base seed; see ``set_random_seed``.
        nth_process : int
            Process index, added to the seed and passed to the progress bar.

        NOTE(review): every initial position is evaluated even if there are
        more of them than ``n_iter``, and ``max_time`` is only checked in the
        main loop. Confirm whether that is intended.
        """
        if initialize is None:
            # Fresh dict per call instead of a shared mutable default.
            initialize = {"grid": 7, "random": 3}

        self.objective_function = objective_function
        self._init_memory(memory)

        set_random_seed(nth_process, random_state)
        start_time = time.time()

        self.p_bar = p_bar_dict[verbosity](nth_process, n_iter, objective_function)

        init_positions = self._init_values(initialize)

        for init_position in init_positions:
            self._search_init_step(init_position)

        for _ in range(len(init_positions), n_iter):
            self._search_iter_step()
            if time_exceeded(start_time, max_time):
                break

        self.p_bar.close()
159
160