Passed
Push — master ( 23d304...5d9e6c )
by Simon
01:08
created

gradient_free_optimizers.search   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 159
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 110
dl 0
loc 159
rs 10
c 0
b 0
f 0
wmc 23

7 Methods

Rating   Name   Duplication   Size   Complexity  
A Search._values2positions() 0 10 2
A Search._positions2values() 0 10 2
A Search._position2value() 0 7 2
A Search._init_values() 0 14 4
B Search.search() 0 53 4
A Search._init_memory() 0 14 4
A Search._score_mem() 0 9 2

2 Functions

Rating   Name   Duplication   Size   Complexity  
A time_exceeded() 0 3 1
A set_random_seed() 0 7 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
from tqdm import tqdm
10
11
from .init_positions import init_grid_search, init_random_search
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
14
15
# Lookup table from the ``verbosity`` argument of ``Search.search`` to the
# progress-bar class that is instantiated for the run.
p_bar_dict = {0: ProgressBarLVL0, 1: ProgressBarLVL1}
19
20
21
def time_exceeded(start_time, max_time):
    """Return whether more than ``max_time`` seconds passed since ``start_time``.

    Parameters
    ----------
    start_time : float
        Reference timestamp as returned by ``time.time()``.
    max_time : float or None
        Time budget in seconds. A falsy value (``None`` or ``0``) means that
        no limit is set.

    Returns
    -------
    bool
        ``True`` if a limit is set and has been exceeded, otherwise ``False``.
    """
    if not max_time:
        # No limit configured: report a real boolean instead of leaking
        # ``None``/``0`` to the caller.
        return False

    run_time = time.time() - start_time
    return bool(run_time > max_time)
24
25
26
def set_random_seed(nth_process, random_state):
    """Seed ``random`` and ``numpy.random`` with a per-process seed.

    The seed is ``random_state + nth_process`` so that several threads or
    processes started with the same ``random_state`` do not all produce the
    same results.

    Parameters
    ----------
    nth_process : int
        Index of the calling process; added to ``random_state``.
    random_state : int or None
        Base seed. If ``None`` a base seed is drawn from ``numpy.random``.
    """
    if random_state is None:
        # Ask for int64 explicitly: the platform default integer type can be
        # 32 bit wide, for which this upper bound is out of range.
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # ``np.random.seed`` only accepts 0 .. 2**32 - 1, so wrap the sum around.
    # Converting to a plain ``int`` keeps ``random.seed`` working, which does
    # not accept NumPy integer types on recent Python versions.
    seed = int(random_state + nth_process) % 2 ** 32

    random.seed(seed)
    np.random.seed(seed)
33
34
35
class Search:
    """Run the initialization and iteration loop of an optimizer.

    NOTE(review): this class reads ``search_space``, ``space_dim``,
    ``eval_times`` and ``iter_times`` and calls ``init_pos()``, ``iterate()``
    and ``evaluate()``. None of them is defined here; presumably they are
    provided by the concrete optimizer class - confirm against the subclasses.
    """

    def _values2positions(self, values):
        """Convert rows of search-space values into rows of index positions.

        Parameters
        ----------
        values : sequence of sequences
            One row per point, holding one value per search-space dimension.

        Returns
        -------
        list of numpy.ndarray
            One integer array per row of ``values`` (the same row layout that
            ``_positions2values`` returns), holding the index of each value
            within its dimension of ``self.search_space``.

        Raises
        ------
        ValueError
            If a value does not occur in its search-space dimension.
        """
        positions = []

        for row in values:
            position = []
            for n, (space_dim, value) in enumerate(zip(self.search_space, row)):
                # Look every value up on its own; comparing the dimension
                # against the whole column at once would pair unrelated
                # elements (or fail on differing lengths).
                hits = np.flatnonzero(np.asarray(space_dim) == value)
                if hits.size == 0:
                    raise ValueError(
                        "value {!r} is not part of search space dimension {}".format(
                            value, n
                        )
                    )
                position.append(int(hits[0]))
            positions.append(np.array(position))

        return positions

    def _positions2values(self, positions):
        """Convert rows of index positions into rows of search-space values.

        Parameters
        ----------
        positions : sequence of sequences of int
            One row per point, holding one index per search-space dimension.

        Returns
        -------
        list of numpy.ndarray
            One array of values per row of ``positions``. The rows are taken
            from a single NumPy array, so dimensions with different dtypes
            are coerced to a common dtype.
        """
        positions_np = np.array(positions)
        if positions_np.size == 0:
            # Nothing to convert; column indexing below would fail otherwise.
            return []

        values_per_dim = [
            np.take(space_dim, positions_np[:, n], axis=0)
            for n, space_dim in enumerate(self.search_space)
        ]

        # Transpose from "one array per dimension" to "one array per point".
        return list(np.array(values_per_dim).T)

    def _init_values(self, init_values):
        """Collect the start positions requested by ``init_values``.

        Parameters
        ----------
        init_values : dict
            May contain the keys ``"random"`` and ``"grid"`` (passed together
            with ``self.space_dim`` to ``init_random_search`` and
            ``init_grid_search``) and ``"warm_start"`` (rows of values that
            are converted with ``_values2positions``).

        Returns
        -------
        list
            All positions, in the order random, grid, warm start.
        """
        init_positions = []

        if "random" in init_values:
            init_positions.extend(
                init_random_search(self.space_dim, init_values["random"])
            )
        if "grid" in init_values:
            init_positions.extend(
                init_grid_search(self.space_dim, init_values["grid"])
            )
        if "warm_start" in init_values:
            init_positions.extend(
                self._values2positions(init_values["warm_start"])
            )

        return init_positions

    def _position2value(self, position):
        """Convert one index position into the list of search-space values.

        Parameters
        ----------
        position : sequence of int
            One index per dimension of ``self.search_space``.

        Returns
        -------
        list
            The value selected by ``position`` in every dimension.
        """
        return [
            space_dim[position[n]]
            for n, space_dim in enumerate(self.search_space)
        ]

    def _score_mem(self, pos):
        """Score ``pos`` with the objective function, caching by ``tuple(pos)``.

        The objective function is only called when ``pos`` is not yet stored
        in ``self.memory_dict``; the new score is stored before it is
        returned.
        """
        pos_tuple = tuple(pos)

        try:
            return self.memory_dict[pos_tuple]
        except KeyError:
            score = self.objective_function(pos)
            self.memory_dict[pos_tuple] = score
            return score

    def _init_memory(self, memory):
        """Select the scoring function according to ``memory``.

        Parameters
        ----------
        memory : bool or dict
            * dict with the keys ``"values"`` and ``"scores"``: scores are
              cached and the cache is pre-filled with the given pairs.
            * other truthy value: scores are cached, starting with an empty
              cache.
            * falsy value: the objective function is called directly.

        Raises
        ------
        KeyError
            If ``memory`` is a dict without ``"values"`` or ``"scores"``.
        """
        if isinstance(memory, dict):
            value_tuples = map(tuple, memory["values"])
            self.memory_dict = dict(zip(value_tuples, memory["scores"]))
            self._score = self._score_mem
        elif memory:
            self.memory_dict = {}
            self._score = self._score_mem
        else:
            # Also covers ``None``, which previously left ``_score`` unset.
            self._score = self.objective_function

    def _score_and_evaluate(self, value, start_time_iter):
        """Score ``value``, report it and record evaluation/iteration times."""
        start_time_eval = time.time()
        score_new = self._score(value)
        self.p_bar.update(1, score_new)
        self.eval_times.append(time.time() - start_time_eval)

        self.evaluate(score_new)
        self.iter_times.append(time.time() - start_time_iter)

    def search(
        self,
        objective_function,
        n_iter,
        initialize=None,
        max_time=None,
        memory=True,
        verbosity=1,
        random_state=None,
        nth_process=0,
    ):
        """Run the initialization positions and then the optimization loop.

        Parameters
        ----------
        objective_function : callable
            Called with a list of search-space values; returns the score.
        n_iter : int
            Total number of iterations; the initialization positions count
            towards it.
        initialize : dict, optional
            Start positions, see ``_init_values``. Defaults to
            ``{"grid": 7, "random": 3}``.
        max_time : float, optional
            Time budget in seconds, checked after each iteration of the main
            loop. ``None`` disables the limit.
        memory : bool or dict
            Score caching, see ``_init_memory``.
        verbosity : int
            Key into ``p_bar_dict`` selecting the progress bar class.
        random_state : int, optional
            Base seed handed to ``set_random_seed``.
        nth_process : int
            Process index handed to ``set_random_seed`` and the progress bar.
        """
        if initialize is None:
            # Built per call instead of sharing one mutable default dict.
            initialize = {"grid": 7, "random": 3}

        self.objective_function = objective_function
        self._init_memory(memory)

        set_random_seed(nth_process, random_state)
        start_time = time.time()

        self.p_bar = p_bar_dict[verbosity](nth_process, n_iter, objective_function)

        try:
            init_positions = self._init_values(initialize)

            # loop to initialize N positions
            for init_position in init_positions:
                start_time_iter = time.time()
                self.init_pos(init_position)

                # Score the value, not the index position, so the objective
                # function and the memory see the same kind of input as in
                # the iteration loop below.
                value_new = self._position2value(init_position)
                self._score_and_evaluate(value_new, start_time_iter)

            # loop to do the iterations
            for _ in range(len(init_positions), n_iter):
                start_time_iter = time.time()
                pos_new = self.iterate()

                value_new = self._position2value(pos_new)
                self._score_and_evaluate(value_new, start_time_iter)

                if time_exceeded(start_time, max_time):
                    break
        finally:
            # Close the progress bar even if the objective function or one of
            # the optimizer hooks raises.
            self.p_bar.close()
159
160