Passed
Push — master ( d39371...69bf6f )
by Simon
03:38
created

gradient_free_optimizers.search.Search.search()   A

Complexity

Conditions 4

Size

Total Lines 48
Code Lines 35

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 35
nop 8
dl 0
loc 48
rs 9.0399
c 0
b 0
f 0

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

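There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object, or passing a whole object instead of several of its individual fields: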

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
from tqdm import tqdm
10
11
from .init_positions import init_grid_search, init_random_search
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .io_search_processor import init_search
14
15
16
# Verbosity level -> progress-bar class instantiated by ``Search.search``.
# Level 0 maps to ProgressBarLVL0, level 1 to ProgressBarLVL1.
p_bar_dict = dict(enumerate((ProgressBarLVL0, ProgressBarLVL1)))
20
21
22
def time_exceeded(start_time, max_time):
    """Return True if more than ``max_time`` seconds have passed since ``start_time``.

    Parameters
    ----------
    start_time : float
        Timestamp obtained from ``time.time()``.
    max_time : float or None
        Time budget in seconds. A falsy value (``None`` or ``0``) means
        "no time limit", in which case this always returns False.

    Returns
    -------
    bool
        Always a real bool (the previous version could leak ``None``/``0``).
    """
    if not max_time:
        return False
    return time.time() - start_time > max_time
25
26
27
def set_random_seed(nth_process, random_state):
    """Seed ``random`` and ``numpy.random`` for one worker.

    The seed is offset by ``nth_process`` so that parallel searches started
    with the same ``random_state`` do not all produce the same results.

    Parameters
    ----------
    nth_process : int
        Index of the process/thread; added to the base seed.
    random_state : int or None
        Base seed. If None, a base seed is drawn from NumPy's global generator.
    """
    if random_state is None:
        # Explicit int64: the platform default integer is 32-bit on some
        # systems (e.g. Windows with NumPy < 2), where high=2**32 - 2 overflows.
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # int(): random.seed rejects numpy integer types on Python >= 3.11.
    # Modulo: np.random.seed only accepts [0, 2**32 - 1]; in-range seeds are
    # unchanged, out-of-range sums wrap instead of raising ValueError.
    seed = int(random_state + nth_process) % (2 ** 32)

    random.seed(seed)
    np.random.seed(seed)
34
35
36
class Search:
    """Drive an optimizer through its initialization and iteration phases.

    NOTE(review): this class uses ``search_space``, ``space_dim``,
    ``init_pos``, ``iterate``, ``evaluate``, ``eval_times`` and ``iter_times``
    without defining them; they are presumably supplied by the optimizer
    class that inherits from it — confirm against the subclasses.
    """

    def _values2positions(self, values):
        """Convert search-space values into index positions.

        Parameters
        ----------
        values : sequence of sequences
            One row per point, one column per dimension of ``self.search_space``.

        Returns
        -------
        list of numpy.ndarray
            One integer position per point (one index per dimension), i.e. the
            same layout that ``_positions2values`` accepts.

        Raises
        ------
        ValueError
            If a value does not occur in its dimension of the search space.
        """
        if len(values) == 0:
            return []

        values_np = np.array(values)
        indices_per_dim = []

        for n, space_dim in enumerate(self.search_space):
            space_dim_np = np.asarray(space_dim)
            dim_indices = []
            # Look up each value on its own: comparing the whole column against
            # the dimension only works when exactly one point is given.
            for value in values_np[:, n]:
                matches = np.flatnonzero(space_dim_np == value)
                if matches.size == 0:
                    raise ValueError(
                        "value {!r} not found in dimension {} of the search "
                        "space".format(value, n)
                    )
                dim_indices.append(matches[0])
            indices_per_dim.append(dim_indices)

        # Transpose (n_dims, n_points) -> (n_points, n_dims) so each list item
        # is one position, mirroring _positions2values.
        return list(np.array(indices_per_dim).T)

    def _positions2values(self, positions):
        """Convert index positions (one row per point) into search-space values.

        Returns a list with one array of values per point.
        """
        pos_converted = []
        positions_np = np.array(positions)

        for n, space_dim in enumerate(self.search_space):
            pos_1d = positions_np[:, n]
            pos_conv = np.take(space_dim, pos_1d, axis=0)
            pos_converted.append(pos_conv)

        return list(np.array(pos_converted).T)

    def _init_positions(self, init_values):
        """Build the flat list of initial positions requested by ``init_values``.

        Recognized keys, processed in this order:

        - ``"random"``: passed to ``init_random_search`` with ``self.space_dim``.
        - ``"grid"``: passed to ``init_grid_search`` with ``self.space_dim``.
        - ``"warm_start"``: search-space values converted by
          ``_values2positions``.

        Other keys are ignored.
        """
        init_positions_list = []

        if "random" in init_values:
            init_positions_list.append(
                init_random_search(self.space_dim, init_values["random"])
            )
        if "grid" in init_values:
            init_positions_list.append(
                init_grid_search(self.space_dim, init_values["grid"])
            )
        if "warm_start" in init_values:
            init_positions_list.append(
                self._values2positions(init_values["warm_start"])
            )

        return [pos for positions in init_positions_list for pos in positions]

    def _evaluate_position(self, objective_function, position, start_time_iter):
        """Score ``position``, report it, and record eval/iteration durations."""
        start_time_eval = time.time()
        score_new = objective_function(position)
        self.p_bar.update(1, score_new)
        self.eval_times.append(time.time() - start_time_eval)

        self.evaluate(score_new)
        self.iter_times.append(time.time() - start_time_iter)

    def search(
        self,
        objective_function,
        n_iter,
        init_values=None,
        max_time=None,
        verbosity=1,
        random_state=None,
        nth_process=0,
    ):
        """Run the optimization loop.

        Parameters
        ----------
        objective_function : callable
            Called with each position; its return value is reported to the
            progress bar and passed to ``self.evaluate``.
        n_iter : int
            Total number of iterations, including the initial positions.
        init_values : dict or None
            Settings for ``_init_positions``; ``None`` means
            ``{"grid": 7, "random": 3}``.
        max_time : float or None
            Wall-clock budget in seconds, checked after each regular (non-init)
            iteration; a falsy value means no limit.
        verbosity : int
            Key into ``p_bar_dict`` selecting the progress bar (0 or 1).
        random_state : int or None
            Base random seed.
        nth_process : int
            Process index; offsets the seed and is passed to the progress bar.
        """
        if init_values is None:
            # Built per call instead of a shared mutable default argument.
            init_values = {"grid": 7, "random": 3}

        set_random_seed(nth_process, random_state)
        start_time = time.time()

        self.p_bar = p_bar_dict[verbosity]()
        self.p_bar.init(nth_process, n_iter, objective_function)

        try:
            init_positions = self._init_positions(init_values)

            # Initialization phase: evaluate every initial position.
            for init_position in init_positions:
                start_time_iter = time.time()
                self.init_pos(init_position)
                self._evaluate_position(
                    objective_function, init_position, start_time_iter
                )

            # Iteration phase: the optimizer proposes the remaining positions.
            for _ in range(len(init_positions), n_iter):
                start_time_iter = time.time()
                pos_new = self.iterate()
                self._evaluate_position(objective_function, pos_new, start_time_iter)

                if time_exceeded(start_time, max_time):
                    break
        finally:
            # Close the progress bar even if the objective function raises.
            self.p_bar.close()
122
123