Passed
Push — master ( 777cc3...c97f4f )
by Simon
01:20
created

gradient_free_optimizers.search.Search._score()   A

Complexity

Conditions 3

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 2
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
from tqdm import tqdm
10
11
from .init_positions import init_grid_search, init_random_search
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
14
15
# Progress-bar class for each ``verbosity`` level accepted by ``Search.search``
# (level 0 -> ProgressBarLVL0, level 1 -> ProgressBarLVL1).
p_bar_dict = dict(enumerate((ProgressBarLVL0, ProgressBarLVL1)))
19
20
21
def time_exceeded(start_time, max_time):
    """Return True if more than ``max_time`` seconds elapsed since ``start_time``.

    ``start_time`` is a ``time.time()`` timestamp. A falsy ``max_time``
    (``None`` or ``0``) means "no time limit", in which case this always
    returns False (previously it leaked ``None``/``0`` instead of a bool).
    """
    if not max_time:
        return False
    return time.time() - start_time > max_time
24
25
26
def set_random_seed(nth_process, random_state):
    """Seed ``random`` and ``numpy.random`` for one worker process.

    Each process is seeded with ``random_state + nth_process`` so that parallel
    searches started from the same base seed do not produce identical results.

    Parameters
    ----------
    nth_process : int
        Index of the process; added to the base seed as an offset.
    random_state : int or None
        Base seed. If None, a base seed is drawn from NumPy's global RNG.
    """
    if random_state is None:
        # Explicit 64-bit dtype: the default integer dtype is 32-bit on some
        # platforms (e.g. Windows with NumPy < 2), where high=2**32 - 2 would
        # raise "high is out of bounds".
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # np.random.seed only accepts [0, 2**32 - 1]; wrap instead of raising when
    # the process offset pushes the seed past that range. int() also turns
    # NumPy integer scalars into a plain int, which random.seed requires on
    # Python 3.11+.
    seed = int(random_state + nth_process) % 2 ** 32

    random.seed(seed)
    np.random.seed(seed)
33
34
35
class Search:
    """Shared search loop for the optimizers.

    NOTE(review): this class relies on members it does not define —
    ``search_space``, ``space_dim``, ``eval_times``, ``iter_times``,
    ``init_pos``, ``iterate`` and ``evaluate`` — presumably supplied by the
    concrete optimizer that uses it; confirm against the subclasses.
    """

    def _values2positions(self, values):
        """Convert search-space values into index positions.

        Inverse of ``_positions2values``.

        Parameters
        ----------
        values : sequence of sequences
            One row per point, one column per dimension of ``self.search_space``.

        Returns
        -------
        list of numpy.ndarray
            One index array per point; entry ``n`` is the index of that
            point's value inside ``self.search_space[n]``.

        Raises
        ------
        ValueError
            If a value does not occur in its search-space dimension.
        """
        if len(values) == 0:
            return []

        values_np = np.array(values)

        pos_per_dim = []
        for n, space_dim in enumerate(self.search_space):
            space_dim = np.asarray(space_dim)
            indices = []
            # Look up each point's value individually; comparing the whole
            # column against space_dim at once only works for a single point.
            for value in values_np[:, n]:
                matches = np.flatnonzero(space_dim == value)
                if matches.size == 0:
                    raise ValueError(
                        "value {!r} not found in search space dimension {}".format(
                            value, n
                        )
                    )
                indices.append(matches[0])
            pos_per_dim.append(indices)

        # Transpose from per-dimension lists to per-point positions, matching
        # the layout returned by _positions2values.
        return list(np.array(pos_per_dim).T)

    def _positions2values(self, positions):
        """Convert index positions into search-space values.

        ``positions`` has one row per point and one column per dimension;
        column ``n`` indexes into ``self.search_space[n]``. Returns one value
        array per point.
        """
        pos_converted = []
        positions_np = np.array(positions)

        for n, space_dim in enumerate(self.search_space):
            pos_1d = positions_np[:, n]
            pos_conv = np.take(space_dim, pos_1d, axis=0)
            pos_converted.append(pos_conv)

        # Rows are dimensions here; transpose to one entry per point.
        return list(np.array(pos_converted).T)

    def _init_positions(self, init_values):
        """Build the flat list of initial positions described by ``init_values``.

        Recognized keys, processed in this order when present: ``"random"``
        and ``"grid"`` (their value is forwarded to ``init_random_search`` /
        ``init_grid_search`` together with ``self.space_dim``) and
        ``"warm_start"`` (search-space values converted via
        ``_values2positions``).
        """
        init_positions_list = []

        if "random" in init_values:
            positions = init_random_search(self.space_dim, init_values["random"])
            init_positions_list.append(positions)
        if "grid" in init_values:
            positions = init_grid_search(self.space_dim, init_values["grid"])
            init_positions_list.append(positions)
        if "warm_start" in init_values:
            positions = self._values2positions(init_values["warm_start"])
            init_positions_list.append(positions)

        # Flatten the per-strategy lists into one list of positions.
        return [item for sublist in init_positions_list for item in sublist]

    def _score(self, pos):
        """Return the objective score for ``pos``.

        When ``self.memory`` is true and ``pos`` was scored before, the cached
        score is returned without calling the objective. Every freshly
        computed score is stored in ``self.memory_dict`` (even when memory is
        disabled).
        """
        pos_tuple = tuple(pos)  # arrays are unhashable; use a tuple as the key

        if self.memory and pos_tuple in self.memory_dict:
            return self.memory_dict[pos_tuple]

        score = self.objective_function(pos)
        self.memory_dict[pos_tuple] = score
        return score

    def _evaluate_and_record(self, pos, start_time_iter):
        """Score ``pos``, report it, and record evaluation/iteration timings.

        ``start_time_iter`` is the timestamp taken at the start of the current
        iteration, so ``iter_times`` also covers proposing the position.
        """
        start_time_eval = time.time()
        score_new = self._score(pos)
        self.p_bar.update(1, score_new)
        self.eval_times.append(time.time() - start_time_eval)

        self.evaluate(score_new)
        self.iter_times.append(time.time() - start_time_iter)

    def search(
        self,
        objective_function,
        n_iter,
        init_values=None,
        max_time=None,
        memory=True,
        verbosity=1,
        random_state=None,
        nth_process=0,
    ):
        """Run the optimization loop.

        Parameters
        ----------
        objective_function : callable
            Called with each candidate position; its return value is the
            score passed to ``self.evaluate``.
        n_iter : int
            Total number of evaluations, initial positions included.
        init_values : dict, optional
            Initialization strategy (see ``_init_positions``). Defaults to
            ``{"grid": 7, "random": 3}``.
        max_time : float, optional
            Wall-clock budget in seconds; a falsy value means no limit.
        memory : bool
            If True, scores of already-evaluated positions are reused.
        verbosity : int
            Key into ``p_bar_dict`` selecting the progress-bar class (0 or 1).
        random_state : int, optional
            Base seed, offset by ``nth_process`` (see ``set_random_seed``).
        nth_process : int
            Process index; used as seed offset and passed to the progress bar.
        """
        # None instead of a dict literal: avoid a shared mutable default.
        if init_values is None:
            init_values = {"grid": 7, "random": 3}

        self.objective_function = objective_function
        self.memory = memory
        self.memory_dict = {}

        set_random_seed(nth_process, random_state)
        start_time = time.time()

        self.p_bar = p_bar_dict[verbosity]()
        self.p_bar.init(nth_process, n_iter, objective_function)

        try:
            # n_iter is the total evaluation budget (the iteration loop below
            # starts counting at len(init_positions)), so never evaluate more
            # initial positions than it allows.
            init_positions = self._init_positions(init_values)[:n_iter]

            # loop to initialize N positions
            for init_position in init_positions:
                start_time_iter = time.time()
                self.init_pos(init_position)
                self._evaluate_and_record(init_position, start_time_iter)

                if time_exceeded(start_time, max_time):
                    return

            # loop to do the iterations
            for _ in range(len(init_positions), n_iter):
                start_time_iter = time.time()
                pos_new = self.iterate()
                self._evaluate_and_record(pos_new, start_time_iter)

                if time_exceeded(start_time, max_time):
                    break
        finally:
            # Close the progress bar even if the objective function raises.
            self.p_bar.close()
136
137