Passed
Push — master ( 4ef66c...afb360 )
by Simon
03:52
created

gradient_free_optimizers.search.Search.search()   B

Complexity

Conditions 5

Size

Total Lines 45
Code Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 32
nop 10
dl 0
loc 45
rs 8.6453
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent as soon as you need more or different data.

There are several approaches to avoid long parameter lists:

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
from tqdm import tqdm
10
11
from .init_positions import Initializer
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .conv import values2positions, positions2values, position2value
14
from .times_tracker import TimesTracker
15
16
# Lookup table from the boolean ``progress_bar`` flag of ``Search.search`` to
# the progress-bar class that gets instantiated for the run.
p_bar_dict = {False: ProgressBarLVL0, True: ProgressBarLVL1}
20
21
22
def time_exceeded(start_time, max_time):
    """Return True when more than ``max_time`` seconds passed since ``start_time``.

    Parameters
    ----------
    start_time : float
        Reference timestamp as returned by ``time.time()``.
    max_time : float or None
        Time budget in seconds. A falsy value (``None`` or ``0``) means
        "no limit", in which case the function always returns False.

    Returns
    -------
    bool
        Always a real boolean (never ``None`` or ``0``), so the result can be
        stored, compared or logged safely.
    """
    if not max_time:
        return False

    run_time = time.time() - start_time
    return bool(run_time > max_time)
25
26
27
def set_random_seed(nth_process, random_state):
    """Sets the random seed separately for each thread (to avoid getting the same results in each thread)

    Both the ``random`` module and ``numpy.random`` are seeded with
    ``random_state + nth_process``, wrapped into the range accepted by
    ``numpy.random.seed``.

    Parameters
    ----------
    nth_process : int or None
        Index of the current process; ``None`` is treated as ``0``.
    random_state : int or None
        Base seed. When ``None`` a base seed is drawn from ``numpy.random``.
    """
    if nth_process is None:
        nth_process = 0

    if random_state is None:
        # An explicit 64-bit dtype keeps the upper bound valid on platforms
        # whose default NumPy integer is only 32 bits wide.
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # ``numpy.random.seed`` only accepts values in [0, 2**32 - 1], so the sum is
    # wrapped instead of letting a large seed/process index raise ValueError.
    # ``int(...)`` converts NumPy integers, which ``random.seed`` may reject.
    seed = (int(random_state) + int(nth_process)) % (2 ** 32)

    random.seed(seed)
    np.random.seed(seed)
37
38
39
class Search(TimesTracker):
    """Run an optimization: evaluate start positions, iterate, collect results.

    The class uses ``self.search_space``, ``self.init_pos``, ``self.iterate``
    and ``self.evaluate`` without defining them; presumably they are supplied
    by the concrete optimizer class that inherits from ``Search``.
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []

    @TimesTracker.eval_time_dec
    def _score(self, pos):
        """Return the score of ``pos``, reusing a stored score when memory is on.

        Every freshly computed score is written to ``self.memory_dict``
        (keyed by ``tuple(pos)``) regardless of the memory setting, because
        ``search`` builds its result arrays from that dictionary.
        """
        pos_tuple = tuple(pos)

        # ``memory`` may be ``True`` or a warm-start dict (see ``_init_memory``).
        # Both must enable the lookup; checking ``is True`` alone would load a
        # warm-start dict and then never read from it.
        memory_enabled = self.memory is True or isinstance(self.memory, dict)

        if memory_enabled and pos_tuple in self.memory_dict:
            return self.memory_dict[pos_tuple]

        score = self.objective_function(pos)
        self.memory_dict[pos_tuple] = score
        return score

    def _init_memory(self, memory):
        """Reset ``self.memory_dict`` and optionally pre-fill it.

        When ``memory`` is a dict, its ``"values"`` entries are converted to
        tuples and paired with the entries of ``"scores"``.
        """
        self.memory_dict = {}

        if isinstance(memory, dict):
            values_list = memory["values"]
            scores = memory["scores"]

            value_tuple_list = list(map(tuple, values_list))
            self.memory_dict = dict(zip(value_tuple_list, scores))

    @TimesTracker.iter_time_dec
    def _initialization(self, init_pos):
        """Evaluate one start position and report it to the progress bar."""
        self.init_pos(init_pos)

        value_new = position2value(self.search_space, init_pos)
        score_new = self._score(value_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, value_new)

    @TimesTracker.iter_time_dec
    def _iteration(self):
        """Ask for the next position, evaluate it and report the result."""
        pos_new = self.iterate()

        value_new = position2value(self.search_space, pos_new)
        score_new = self._score(value_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, value_new)

    def search(
        self,
        objective_function,
        n_iter,
        initialize=None,
        max_time=None,
        memory=True,
        progress_bar=True,
        print_results=True,
        random_state=None,
        nth_process=None,
    ):
        """Run the optimization and store the results on the instance.

        Parameters
        ----------
        objective_function : callable
            Called with the values of a position; its return value is the score.
        n_iter : int
            Upper bound of the iteration loop, which runs over
            ``range(len(init_positions), n_iter)``; the start positions
            therefore count toward this number.
        initialize : dict or None
            Passed to ``Initializer.set_pos`` to create the start positions.
            ``None`` selects ``{"grid": 4, "random": 2, "vertices": 4}``.
        max_time : float or None
            Time budget in seconds, checked before every evaluation. A falsy
            value disables the limit.
        memory : bool or dict
            ``True`` enables score lookups in ``self.memory_dict``. A dict with
            the keys ``"values"`` and ``"scores"`` pre-fills that dictionary
            and enables lookups as well.
        progress_bar : bool
            Selects the progress-bar class through ``p_bar_dict``.
        print_results : bool
            Forwarded to the progress bar's ``close`` method.
        random_state : int or None
            Forwarded to ``set_random_seed``.
        nth_process : int or None
            Forwarded to ``set_random_seed`` and to the progress bar.

        After the call ``self.values``, ``self.scores`` (reshaped to a single
        column), ``self.best_score`` and ``self.best_values`` are set.
        """
        start_time = time.time()

        # A fresh dict per call: a dict literal in the signature would be one
        # shared object that any callee could mutate for all later calls.
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}

        self.objective_function = objective_function
        self.memory = memory

        self._init_memory(memory)
        self.p_bar = p_bar_dict[progress_bar](nth_process, n_iter, objective_function)

        set_random_seed(nth_process, random_state)

        # get init positions
        init = Initializer(self.search_space)
        init_positions = init.set_pos(initialize)

        # loop to initialize N positions
        for init_pos in init_positions:
            if time_exceeded(start_time, max_time):
                break
            self._initialization(init_pos)

        # loop to do the iterations
        for _ in range(len(init_positions), n_iter):
            if time_exceeded(start_time, max_time):
                break
            self._iteration()

        # Results come from the memory dictionary, which holds every evaluated
        # position (and any warm-start entries) exactly once.
        self.values = np.array(list(self.memory_dict.keys()))
        self.scores = np.array(list(self.memory_dict.values())).reshape(-1, 1)

        self.p_bar.close(print_results)

        self.best_score = self.p_bar.score_best
        self.best_values = self.p_bar.values_best
131
132