Passed
Push — master ( b0b7f3...e3d174 )
by Simon
04:31
created

Search._init_memory()   A

Complexity

Conditions 3

Size

Total Lines 28
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 20
nop 2
dl 0
loc 28
rs 9.4
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
import pandas as pd
10
11
from .init_positions import Initializer
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .times_tracker import TimesTracker
14
from .results_manager import ResultsManager
15
from .memory import Memory
16
from .converter import Converter
17
from .print_info import print_info
18
19
# Selects the progress-bar class from the boolean ``progress_bar``
# verbosity flag: True -> ProgressBarLVL1, False -> ProgressBarLVL0.
p_bar_dict = dict(
    [
        (False, ProgressBarLVL0),
        (True, ProgressBarLVL1),
    ]
)
23
24
25
def time_exceeded(start_time, max_time):
    """Return True once more than ``max_time`` seconds have elapsed.

    Parameters
    ----------
    start_time : float
        Timestamp taken with ``time.time()`` when the run started.
    max_time : float or None
        Time budget in seconds. A falsy value (``None`` or ``0``) disables
        the limit, as before.

    Returns
    -------
    bool
        Always a real bool (previously ``None``/``0`` leaked through when
        the limit was disabled).
    """
    if not max_time:
        return False
    return time.time() - start_time > max_time
28
29
30
def score_exceeded(score_best, max_score):
    """Return True when the best score so far reaches ``max_score``.

    Parameters
    ----------
    score_best : number
        Best score found so far.
    max_score : number or None
        Target score; ``None`` disables the check. ``0`` is a valid target
        (e.g. for negated losses) and is no longer silently ignored.

    Returns
    -------
    bool
    """
    if max_score is None:
        return False
    return score_best >= max_score
32
33
34
def set_random_seed(nth_process, random_state):
    """
    Seed Python's ``random`` module and NumPy's global RNG.

    The seed is ``random_state + nth_process``, so parallel processes
    started with the same ``random_state`` do not produce identical
    results.

    Parameters
    ----------
    nth_process : int or None
        Index of the process; ``None`` is treated as 0.
    random_state : int or None
        Base seed; ``None`` draws a base seed from NumPy's global RNG.
    """
    if nth_process is None:
        nth_process = 0

    if random_state is None:
        # dtype=np.int64: the default integer dtype is 32-bit on some
        # platforms (e.g. Windows with older NumPy), where this upper bound
        # raises "high is out of bounds".
        random_state = np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)

    # int(): random.seed rejects NumPy integer types on Python >= 3.11.
    # Modulo: np.random.seed only accepts values in [0, 2**32 - 1], so wrap
    # large random_state / nth_process sums instead of raising.
    seed = int(random_state + nth_process) % 2 ** 32

    random.seed(seed)
    np.random.seed(seed)
47
48
49
class Search(TimesTracker):
    """Run an optimization: initial positions first, then iterations.

    This class only orchestrates the loop. It reads attributes it does not
    define itself -- ``search_space``, ``init_pos``, ``evaluate`` and
    ``iterate`` -- which are presumably provided by the concrete optimizer
    class (TODO confirm against the subclasses).
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

    @TimesTracker.eval_time
    def _score(self, pos):
        """Score ``pos`` with ``self.score`` (timed via ``eval_time``)."""
        return self.score(pos)

    def _init_memory(self, memory):
        """Reset the memory dicts and fill ``memory_dict`` from a warm start.

        ``memory`` is accepted for interface compatibility but is not used;
        the warm-start data is read from ``self.memory_warm_start``. When it
        is a DataFrame with a column for every search-space parameter,
        ``memory_dict`` maps each row's tuple of parameter values to that
        row's ``"score"`` column. Otherwise a warning naming every missing
        parameter is printed and the memory stays empty.
        """
        memory_warm_start = self.memory_warm_start

        self.memory_dict = {}
        self.memory_dict_new = {}

        if not isinstance(memory_warm_start, pd.DataFrame):
            return

        para_names = list(self.search_space.keys())
        parameter = set(para_names)
        memory_para = set(memory_warm_start.columns)

        if parameter <= memory_para:
            values_list = list(memory_warm_start[para_names].values)
            scores = memory_warm_start["score"]

            value_tuple_list = list(map(tuple, values_list))
            self.memory_dict = dict(zip(value_tuple_list, scores))
        else:
            # Report *all* missing parameters (previously only one, picked
            # arbitrarily from the set); sorted for a deterministic message.
            missing = sorted(str(name) for name in parameter - memory_para)

            print(
                "\nWarning:",
                ", ".join('"{}"'.format(name) for name in missing),
                "is in search_space but not in memory dataframe",
            )
            print(
                "Optimization run will continue "
                "without memory warm start\n"
            )

    @TimesTracker.iter_time
    def _initialization(self, init_pos):
        """Evaluate one initial position and report it to the progress bar."""
        self.init_pos(init_pos)

        score_new = self._score(init_pos)
        self.evaluate(score_new)

        self.p_bar.update(score_new, init_pos)

    @TimesTracker.iter_time
    def _iteration(self):
        """Ask for a new position, evaluate it and report it."""
        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, pos_new)

    def _init_search(self):
        """Set up memory, progress bar and RNG seed; return initial positions."""
        self._init_memory(self.memory)
        self.p_bar = p_bar_dict[self.progress_bar](
            self.nth_process, self.n_iter, self.objective_function
        )
        set_random_seed(self.nth_process, self.random_state)

        # Work on a copy: ``self.initialize`` is either the caller's dict or
        # the *shared* default argument of ``search``; mutating it in place
        # leaked ``warm_start`` into every later call.
        self.initialize = dict(self.initialize)
        if self.warm_start is not None:
            self.initialize["warm_start"] = self.warm_start

        # get init positions
        init = Initializer(self.conv)
        init_positions = init.set_pos(self.initialize)

        return init_positions

    def _early_stop(self):
        """Return True when the time budget or the target score is reached."""
        return bool(
            time_exceeded(self.start_time, self.max_time)
            or score_exceeded(self.p_bar.score_best, self.max_score)
        )

    def _init_verb_dict(self, verb_dict):
        """Normalize the ``verbosity`` argument into a complete dict.

        ``None``/``False`` switch every output off, ``True`` switches every
        output on, and a dict has its missing keys filled with ``True``.
        A new dict is returned; the argument is never mutated.
        """
        verb_default = {
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        }

        if verb_dict in [None, False]:
            return {verb_key: False for verb_key in verb_default}
        if verb_dict is True:
            # Previously crashed with TypeError (bool is not iterable).
            return dict(verb_default)

        return {**verb_default, **verb_dict}

    def search(
        self,
        objective_function,
        n_iter,
        initialize={"grid": 8, "random": 4, "vertices": 8},
        warm_start=None,
        max_time=None,
        max_score=None,
        memory=True,
        memory_warm_start=None,
        verbosity={
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        },
        random_state=None,
        nth_process=None,
    ):
        """Run the optimization and store the outcome on ``self``.

        The dict defaults of ``initialize`` and ``verbosity`` are only read,
        never mutated, so they are safe to share between calls.

        Parameters
        ----------
        objective_function : callable
            Function being optimized; wrapped by ``ResultsManager``.
        n_iter : int
            Total number of evaluations (initial positions plus iterations).
        initialize : dict
            Initialization spec passed to ``Initializer.set_pos``.
        warm_start : optional
            If not None, added to the initialization spec under
            ``"warm_start"``.
        max_time : float or None
            Time budget in seconds; a falsy value disables it.
        max_score : number or None
            Stop as soon as the best score is >= this value.
        memory : bool
            If True, scores are routed through ``Memory`` (exact semantics
            are defined in that module).
        memory_warm_start : pandas.DataFrame or None
            Previously evaluated parameters with a ``"score"`` column.
        verbosity : dict, bool or None
            Output switches; see ``_init_verb_dict``.
        random_state : int or None
            Base RNG seed; see ``set_random_seed``.
        nth_process : int or None
            Process index added to the seed.

        Results are stored in ``self.results`` (DataFrame),
        ``self.best_score``, ``self.best_value`` and ``self.best_para``.
        """
        self.start_time = time.time()

        verbosity = self._init_verb_dict(verbosity)

        self.objective_function = objective_function
        self.n_iter = n_iter
        self.initialize = initialize
        self.warm_start = warm_start
        self.max_time = max_time
        self.max_score = max_score
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.progress_bar = verbosity["progress_bar"]
        self.random_state = random_state
        self.nth_process = nth_process

        conv = Converter(self.search_space)
        self.conv = conv
        results = ResultsManager(objective_function, conv)
        init_positions = self._init_search()

        if memory is True:
            mem = Memory(memory_warm_start, conv)
            self.score = mem.memory(results.score)
        else:
            self.score = results.score

        # loop to initialize N positions (never more than n_iter)
        for init_pos, _ in zip(init_positions, range(n_iter)):
            if self._early_stop():
                break
            self._initialization(init_pos)

        # loop to do the remaining iterations
        for _ in range(len(init_positions), n_iter):
            if self._early_stop():
                break
            self._iteration()

        self.results = pd.DataFrame(results.results_list)

        self.best_score = self.p_bar.score_best
        self.best_value = conv.position2value(self.p_bar.pos_best)
        self.best_para = conv.value2para(self.best_value)

        eval_time = np.array(self.eval_times).sum()
        iter_time = np.array(self.iter_times).sum()

        self.p_bar.close()

        print_info(
            verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            eval_time,
            iter_time,
            self.n_iter,
        )
231
232