Passed
Push — master ( d39371...69bf6f )
by Simon
03:38
created

gradient_free_optimizers.search   A

Complexity

Total Complexity 15

Size/Duplication

Total Lines 122
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 84
dl 0
loc 122
rs 10
c 0
b 0
f 0
wmc 15

4 Methods

Rating   Name   Duplication   Size   Complexity  
A Search._values2positions() 0 10 2
A Search.search() 0 48 4
A Search._positions2values() 0 10 2
A Search._init_positions() 0 14 4

2 Functions

Rating   Name   Duplication   Size   Complexity  
A time_exceeded() 0 3 1
A set_random_seed() 0 7 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
from tqdm import tqdm
10
11
from .init_positions import init_grid_search, init_random_search
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .io_search_processor import init_search
14
15
16
# Progress-bar class selected by the ``verbosity`` argument of Search.search:
# index 0 -> ProgressBarLVL0, index 1 -> ProgressBarLVL1.
p_bar_dict = dict(enumerate([ProgressBarLVL0, ProgressBarLVL1]))
20
21
22
def time_exceeded(start_time, max_time):
    """Tell whether more than ``max_time`` seconds elapsed since ``start_time``.

    A falsy ``max_time`` (``None`` or ``0``) means "no limit"; in that case
    the falsy ``max_time`` value itself is returned.
    """
    if not max_time:
        return max_time
    elapsed = time.time() - start_time
    return elapsed > max_time
25
26
27
def set_random_seed(nth_process, random_state):
    """Seed Python's ``random`` module and NumPy's global RNG for one process.

    The seed is ``random_state + nth_process`` so that parallel processes
    sharing one ``random_state`` still draw different sequences (to avoid
    getting the same results in each process).

    Args:
        nth_process: integer offset added to the base seed (process index).
        random_state: base seed; if ``None``, one is drawn from NumPy's
            current global RNG state.
    """
    if random_state is None:
        # Explicit int64: the default integer dtype can be 32-bit (e.g. NumPy < 2
        # on Windows), which cannot represent this upper bound.
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # int(): random.seed() rejects NumPy integer scalars on Python 3.11+.
    # Modulo: np.random.seed() only accepts seeds in [0, 2**32), and the
    # offset by nth_process could otherwise push the sum out of range.
    seed = (int(random_state) + int(nth_process)) % 2 ** 32
    random.seed(seed)
    np.random.seed(seed)
34
35
36
class Search:
    """Generic search loop shared by the optimizers.

    This class defines no ``__init__``; the class it is combined with must
    provide everything used below:

    - ``search_space``: one array of candidate values per dimension.
    - ``space_dim``: forwarded to ``init_random_search`` / ``init_grid_search``
      (presumably the size of each dimension; TODO confirm).
    - ``eval_times`` / ``iter_times``: lists that timing values are appended to.
    - ``init_pos(position)``, ``iterate()`` and ``evaluate(score)``: hooks.
    """

    def _values2positions(self, values):
        """Convert search-space values into index positions.

        Inverse of ``_positions2values``.

        Args:
            values: sequence of points. Each point holds one value per
                dimension of ``self.search_space``, in the same order.

        Returns:
            A list with one 1-D integer array per point. Each array holds, for
            every dimension, the index of the first matching entry in that
            dimension's value array. Empty input gives an empty list.

        Raises:
            ValueError: if a value does not occur in its dimension.
        """
        values_np = np.array(values)
        if values_np.size == 0:
            return []

        indices_per_dim = []
        for n, space_dim in enumerate(self.search_space):
            space_dim_np = np.asarray(space_dim)
            dim_indices = []
            for value in values_np[:, n]:
                # Search the dimension for each value individually. Comparing
                # the whole column against ``space_dim`` at once would pair
                # elements by position (or fail to broadcast).
                matches = np.flatnonzero(space_dim_np == value)
                if matches.size == 0:
                    raise ValueError(
                        "value {!r} not found in search space dimension {}".format(
                            value, n
                        )
                    )
                dim_indices.append(matches[0])
            indices_per_dim.append(dim_indices)

        # (n_dims, n_points) -> one position array per point
        return list(np.array(indices_per_dim).T)

    def _positions2values(self, positions):
        """Convert index positions into search-space values.

        Args:
            positions: sequence of points. Each point holds one index per
                dimension of ``self.search_space``.

        Returns:
            A list with one 1-D array of values per point. Empty input gives
            an empty list.
        """
        positions_np = np.array(positions)
        if positions_np.size == 0:
            return []

        values_per_dim = [
            np.take(space_dim, positions_np[:, n], axis=0)
            for n, space_dim in enumerate(self.search_space)
        ]
        # (n_dims, n_points) -> one value array per point
        return list(np.array(values_per_dim).T)

    def _init_positions(self, init_values):
        """Build the list of initial positions requested by ``init_values``.

        Args:
            init_values: dict that may contain ``"random"`` and ``"grid"``
                (counts passed to the respective init helpers) and
                ``"warm_start"`` (a sequence of points given as values).

        Returns:
            A flat list of positions, in the order random, grid, warm_start.
        """
        init_positions_list = []

        if "random" in init_values:
            init_positions_list.append(
                init_random_search(self.space_dim, init_values["random"])
            )
        if "grid" in init_values:
            init_positions_list.append(
                init_grid_search(self.space_dim, init_values["grid"])
            )
        if "warm_start" in init_values:
            init_positions_list.append(
                self._values2positions(init_values["warm_start"])
            )

        return [
            position for positions in init_positions_list for position in positions
        ]

    def _score_position(self, objective_function, position, start_time_iter):
        """Score ``position`` and record eval/iteration times and progress."""
        start_time_eval = time.time()
        score_new = objective_function(position)
        self.p_bar.update(1, score_new)
        self.eval_times.append(time.time() - start_time_eval)

        self.evaluate(score_new)
        self.iter_times.append(time.time() - start_time_iter)

    def search(
        self,
        objective_function,
        n_iter,
        init_values=None,
        max_time=None,
        verbosity=1,
        random_state=None,
        nth_process=0,
    ):
        """Run the optimization loop.

        Args:
            objective_function: called with each position; its return value
                is the score passed to ``evaluate``.
            n_iter: total number of evaluations. Initial positions count
                toward this budget.
            init_values: dict describing the initial positions (see
                ``_init_positions``). Defaults to ``{"grid": 7, "random": 3}``.
            max_time: wall-clock limit in seconds, checked after every
                evaluation. A falsy value disables the limit.
            verbosity: key into ``p_bar_dict`` selecting the progress bar.
            random_state: base seed forwarded to ``set_random_seed``.
            nth_process: process index, used to offset the seed and forwarded
                to the progress bar.
        """
        if init_values is None:
            init_values = {"grid": 7, "random": 3}

        set_random_seed(nth_process, random_state)
        start_time = time.time()

        self.p_bar = p_bar_dict[verbosity]()
        self.p_bar.init(nth_process, n_iter, objective_function)

        try:
            init_positions = self._init_positions(init_values)

            for nth_iter in range(n_iter):
                start_time_iter = time.time()
                if nth_iter < len(init_positions):
                    # Initialization phase: evaluate the predefined positions.
                    position = init_positions[nth_iter]
                    self.init_pos(position)
                else:
                    # Optimization phase: let the optimizer propose positions.
                    position = self.iterate()

                self._score_position(objective_function, position, start_time_iter)

                if time_exceeded(start_time, max_time):
                    break
        finally:
            # Always release the progress bar, even if the objective raises.
            self.p_bar.close()
122
123