Passed
Push — master ( bac9d4...455948 )
by Simon
04:25
created

gradient_free_optimizers.search.score_exceeded()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
import pandas as pd
10
11
from .init_positions import Initializer
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .times_tracker import TimesTracker
14
from .results_manager import ResultsManager
15
from .memory import Memory
16
from .print_info import print_info
17
18
# Lookup table from the boolean ``verbosity["progress_bar"]`` flag to the
# progress-bar class that ``Search._init_search`` instantiates:
# ``ProgressBarLVL1`` when the flag is True, ``ProgressBarLVL0`` when False.
p_bar_dict = {False: ProgressBarLVL0, True: ProgressBarLVL1}
22
23
24
def time_exceeded(start_time, max_time):
    """Report whether more than ``max_time`` seconds elapsed since ``start_time``.

    A falsy ``max_time`` (``None`` or ``0``) means "no time limit"; in that
    case the falsy ``max_time`` value itself is returned.
    """
    elapsed = time.time() - start_time
    if not max_time:
        return max_time
    return elapsed > max_time
27
28
29
def score_exceeded(score_best, max_score):
    """Return True once the best score so far has reached ``max_score``.

    ``max_score=None`` disables the check. Every other value, including
    ``0``, is treated as a real target. The former truthiness test
    (``max_score and ...``) silently ignored ``max_score=0``, e.g. a target
    of 0 for an objective that returns negated losses.
    """
    if max_score is None:
        return False
    # bool() normalises numpy scalar comparison results to a plain bool.
    return bool(score_best >= max_score)
31
32
33
def set_random_seed(nth_process, random_state):
    """
    Seed ``random`` and ``numpy.random`` separately for each process.

    The seed is ``random_state + nth_process``, so parallel runs sharing a
    ``random_state`` still get different random sequences.

    Parameters
    ----------
    nth_process : int or None
        Index of the current process; ``None`` is treated as ``0``.
    random_state : int or None
        Base seed; ``None`` draws a fresh base seed from ``numpy.random``.
    """
    if nth_process is None:
        nth_process = 0

    if random_state is None:
        # An explicit int64 dtype keeps ``high`` in range where numpy's default
        # integer is 32-bit (e.g. Windows with numpy < 2). ``int`` converts the
        # numpy scalar back to a Python int, which ``random.seed`` requires on
        # Python >= 3.11.
        random_state = int(
            np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)
        )

    # ``np.random.seed`` only accepts seeds in [0, 2**32 - 1]; wrap the sum so
    # a large base seed plus the process index cannot fall outside that range.
    # Seeds already in range are unchanged.
    seed = (int(random_state) + int(nth_process)) % 2 ** 32

    random.seed(seed)
    np.random.seed(seed)
46
47
48
class Search(TimesTracker):
    """Run an optimizer through its initialization and iteration phases.

    NOTE(review): this class reads attributes it never sets itself
    (``self.conv``, ``self.init_pos``, ``self.iterate``, ``self.evaluate``);
    presumably a concrete optimizer subclass provides them -- confirm.
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

    @TimesTracker.eval_time
    def _score(self, pos):
        """Evaluate ``pos`` via ``self.score`` (wrapped by ``TimesTracker.eval_time``)."""
        return self.score(pos)

    def _init_memory(self, memory):
        """Build ``self.memory_dict`` from ``self.memory_warm_start``.

        If ``memory_warm_start`` is a DataFrame that has a column for every
        search-space parameter, each row's parameter values (as a tuple) are
        mapped to that row's ``"score"``. If parameters are missing, a warning
        is printed and the dict stays empty.

        ``memory`` is not used here; it is kept for interface compatibility.
        """
        memory_warm_start = self.memory_warm_start

        self.memory_dict = {}
        self.memory_dict_new = {}

        if not isinstance(memory_warm_start, pd.DataFrame):
            return

        para_names = list(self.conv.search_space.keys())
        missing = set(para_names) - set(memory_warm_start.columns)

        if missing:
            # Report every missing parameter, not only the first one.
            missing_str = ", ".join(
                '"{}"'.format(name) for name in sorted(map(str, missing))
            )
            print(
                "\nWarning:",
                missing_str,
                "is in search_space but not in memory dataframe",
            )
            print(
                "Optimization run will continue "
                "without memory warm start\n"
            )
            return

        values_list = list(memory_warm_start[para_names].values)
        scores = memory_warm_start["score"]

        value_tuple_list = list(map(tuple, values_list))
        self.memory_dict = dict(zip(value_tuple_list, scores))

    @TimesTracker.iter_time
    def _initialization(self, init_pos):
        """Evaluate one initial position and report it to the progress bar."""
        self.init_pos(init_pos)

        score_new = self._score(init_pos)
        self.evaluate(score_new)

        self.p_bar.update(score_new, init_pos)

    @TimesTracker.iter_time
    def _iteration(self):
        """Ask the optimizer for a new position, evaluate and report it."""
        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, pos_new)

    def _init_search(self):
        """Set up memory, progress bar and seeds; return the initial positions."""
        self._init_memory(self.memory)
        self.p_bar = p_bar_dict[self.progress_bar](
            self.nth_process, self.n_iter, self.objective_function
        )
        set_random_seed(self.nth_process, self.random_state)

        # ``self.initialize`` is a private copy made in ``search``, so this
        # write cannot leak into the caller's dict or the default argument.
        if self.warm_start is not None:
            self.initialize["warm_start"] = self.warm_start

        init = Initializer(self.conv)
        return init.set_pos(self.initialize)

    def _early_stop(self):
        """Return True when the time budget or the target score is reached."""
        return bool(
            time_exceeded(self.start_time, self.max_time)
            or score_exceeded(self.p_bar.score_best, self.max_score)
        )

    def _init_verb_dict(self, verb_dict):
        """Normalise ``verbosity`` into a dict with all three flags present.

        ``None``/``False`` turn every flag off and ``True`` turns every flag
        on. A dict is completed with ``True`` for any missing key. A new dict
        is always returned, so neither the caller's dict nor the default
        argument of ``search`` is modified.
        """
        if verb_dict in [None, False]:
            return {
                "progress_bar": False,
                "print_results": False,
                "print_times": False,
            }

        verb_default = {
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        }

        if verb_dict is True:
            return verb_default

        verb = dict(verb_dict)
        for verb_key, verb_value in verb_default.items():
            verb.setdefault(verb_key, verb_value)

        return verb

    def print_info(self, *args):
        """Forward to the module-level ``print_info`` helper."""
        print_info(*args)

    def search(
        self,
        objective_function,
        n_iter,
        initialize={"grid": 8, "random": 4, "vertices": 8},
        warm_start=None,
        max_time=None,
        max_score=None,
        memory=True,
        memory_warm_start=None,
        verbosity={
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        },
        random_state=None,
        nth_process=None,
    ):
        """Run the optimization and store the results on ``self``.

        Parameters
        ----------
        objective_function : callable
            Function being optimized; passed to ``ResultsManager``, the
            progress bar and ``print_info``.
        n_iter : int
            Total number of evaluations, including initial positions.
        initialize : dict
            Initialization strategy passed to ``Initializer.set_pos``. It is
            copied before use, so the caller's dict is never modified.
        warm_start : optional
            If not ``None``, added to the initialization under ``"warm_start"``.
        max_time : float or None
            Wall-clock budget in seconds; ``None`` disables it.
        max_score : float or None
            Stop once the best score reaches this value; ``None`` disables it.
        memory : bool
            When exactly ``True``, scores are cached through ``Memory``.
        memory_warm_start : pandas.DataFrame or None
            Previously evaluated parameters and scores used to seed memory.
        verbosity : dict, bool or None
            Output flags; see ``_init_verb_dict``.
        random_state, nth_process : int or None
            Passed to ``set_random_seed``.

        Sets ``self.results`` (DataFrame of all evaluations), ``self.best_score``,
        ``self.best_value`` and ``self.best_para``.
        """
        self.start_time = time.time()

        verbosity = self._init_verb_dict(verbosity)

        self.objective_function = objective_function
        self.n_iter = n_iter
        # Copy: ``_init_search`` writes ``warm_start`` into this dict. Without
        # the copy, the key leaked into the shared default argument (or the
        # caller's dict) and persisted across later ``search`` calls.
        self.initialize = dict(initialize)
        self.warm_start = warm_start
        self.max_time = max_time
        self.max_score = max_score
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.progress_bar = verbosity["progress_bar"]
        self.random_state = random_state
        self.nth_process = nth_process

        results = ResultsManager(objective_function, self.conv)
        init_positions = self._init_search()

        if memory is True:
            mem = Memory(memory_warm_start, self.conv)
            self.score = mem.memory(results.score)
        else:
            self.score = results.score

        # Evaluate the initial positions, but never more than n_iter of them.
        for init_pos, _ in zip(init_positions, range(n_iter)):
            if self._early_stop():
                break
            self._initialization(init_pos)

        # Spend the remaining budget on optimizer iterations.
        for _ in range(len(init_positions), n_iter):
            if self._early_stop():
                break
            self._iteration()

        self.results = pd.DataFrame(results.results_list)

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        eval_time = np.array(self.eval_times).sum()
        iter_time = np.array(self.iter_times).sum()

        self.p_bar.close()

        self.print_info(
            verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            eval_time,
            iter_time,
            self.n_iter,
        )
234
235