Passed
Push — master ( 7aa4c7...2b6f90 )
by Simon
04:34
created

Search.print_info()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
8
import numpy as np
9
import pandas as pd
10
11
from .init_positions import Initializer
12
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
13
from .times_tracker import TimesTracker
14
from .memory import Memory
15
from .print_info import print_info
16
17
18
def time_exceeded(start_time, max_time):
    """Check whether more than ``max_time`` seconds passed since ``start_time``.

    A falsy ``max_time`` (``None`` or ``0``) disables the limit; in that case
    ``max_time`` itself is returned, otherwise a bool is returned.
    """
    elapsed = time.time() - start_time
    if not max_time:
        return max_time
    return elapsed > max_time
21
22
23
def score_exceeded(score_best, max_score):
    """Return True if ``score_best`` has reached the target ``max_score``.

    ``max_score=None`` disables the check. A target of ``0`` is a real
    threshold (e.g. a negated loss that is perfect at zero), so it is
    compared like any other value instead of being treated as "no limit".
    """
    return max_score is not None and score_best >= max_score
25
26
27
def set_random_seed(nth_process, random_state):
    """Seed ``random`` and ``numpy.random`` separately for each process.

    The seed is ``random_state + nth_process`` so that parallel processes
    (or threads) sharing a ``random_state`` do not produce identical results.
    It is reduced modulo ``2**32`` because ``numpy.random.seed`` only accepts
    values in ``[0, 2**32 - 1]``; values already in that range are unchanged.

    Parameters
    ----------
    nth_process : int or None
        Process index; ``None`` is treated as ``0``.
    random_state : int or None
        Base seed; ``None`` draws one from numpy's current global state.
    """
    if nth_process is None:
        nth_process = 0

    if random_state is None:
        # explicit int64: the legacy default integer type is 32-bit on some
        # platforms (e.g. Windows before NumPy 2), where 2**32 - 2 overflows
        random_state = int(np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64))

    # int(): ``random.seed`` rejects NumPy integer types on Python >= 3.11;
    # modulo keeps ``random_state + nth_process`` inside numpy's seed range
    seed = int(random_state + nth_process) % 2 ** 32

    random.seed(seed)
    np.random.seed(seed)
40
41
42
class Search(TimesTracker):
    """Run an optimizer's initialization and iteration loop.

    This class does not define ``init_pos``, ``iterate``, ``evaluate``,
    ``conv``, ``initialize`` or ``results_mang``; they are used here and
    must be provided by the optimizer subclass (or the base class).
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

    @TimesTracker.eval_time
    def _score(self, pos):
        """Evaluate the wrapped score function at ``pos``."""
        return self.score(pos)

    @TimesTracker.iter_time
    def _initialization(self, init_pos, nth_iter):
        """Score one initial position and report it to the progress bar."""
        self.init_pos(init_pos)

        score_new = self._score(init_pos)
        self.evaluate(score_new)

        self.p_bar.update(score_new, init_pos, nth_iter)

    @TimesTracker.iter_time
    def _iteration(self, nth_iter):
        """Ask the optimizer for a new position, score it and report it."""
        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.p_bar.update(score_new, pos_new, nth_iter)

    def _init_search(self):
        """Create the progress bar, seed the RNGs and return init positions."""
        # ProgressBarLVL1 when "progress_bar" is requested, else ProgressBarLVL0;
        # both take the same constructor arguments.
        if "progress_bar" in self.verbosity:
            p_bar_class = ProgressBarLVL1
        else:
            p_bar_class = ProgressBarLVL0
        self.p_bar = p_bar_class(
            self.nth_process, self.n_iter, self.objective_function
        )

        set_random_seed(self.nth_process, self.random_state)

        # get init positions
        init = Initializer(self.conv)
        return init.set_pos(self.initialize)

    def _early_stop(self):
        """Return True once the time budget or the target score is reached."""
        # ``or`` short-circuits: the score check only runs when the time
        # budget has not been exceeded.
        return bool(
            time_exceeded(self.start_time, self.max_time)
            or score_exceeded(self.p_bar.score_best, self.max_score)
        )

    def _init_verb_dict(self, verb_dict):
        """Fill a verbosity dict with defaults.

        ``None``/``False`` yield a new dict with every output disabled.
        Otherwise missing keys are added to ``verb_dict`` in place with the
        value ``True`` and the same dict object is returned.
        """
        if verb_dict in [None, False]:
            return {
                "progress_bar": False,
                "print_results": False,
                "print_times": False,
            }

        verb_default = {
            "progress_bar": True,
            "print_results": True,
            "print_times": True,
        }
        for verb_key, verb_value in verb_default.items():
            verb_dict.setdefault(verb_key, verb_value)

        return verb_dict

    def print_info(self, *args):
        """Forward ``args`` to the module-level ``print_info`` function."""
        print_info(*args)

    def search(
        self,
        objective_function,
        n_iter,
        max_time=None,
        max_score=None,
        memory=True,
        memory_warm_start=None,
        verbosity=None,
        random_state=None,
        nth_process=None,
    ):
        """Run the optimization loop.

        Parameters
        ----------
        objective_function : callable
            Function to optimize; wrapped by ``results_mang.score``.
        n_iter : int
            Total number of iterations, including the initial positions.
        max_time : float, optional
            Wall-clock budget in seconds; ``None`` means no limit.
        max_score : float, optional
            Stop once the best score reaches this value; ``None`` means no
            limit.
        memory : bool
            Only the value ``True`` wraps the objective function with
            ``Memory``; any other value runs without it.
        memory_warm_start : optional
            Passed to ``Memory`` when ``memory`` is ``True``.
        verbosity : list of str or False, optional
            Subset of ``"progress_bar"``, ``"print_results"`` and
            ``"print_times"``. ``False`` disables all output; ``None`` (the
            default) enables all three.
        random_state : int, optional
            Base seed; see ``set_random_seed``.
        nth_process : int, optional
            Process index, added to the seed and passed to the progress bar.

        Results are stored on ``self.results``, ``self.best_score``,
        ``self.best_value``, ``self.best_para`` and ``self.memory_dict``.
        """
        self.start_time = time.time()

        if verbosity is None:
            # fresh list per call instead of a shared mutable default
            verbosity = ["progress_bar", "print_results", "print_times"]
        elif verbosity is False:
            verbosity = []

        self.objective_function = objective_function
        self.n_iter = n_iter
        self.max_time = max_time
        self.max_score = max_score
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.verbosity = verbosity
        self.random_state = random_state
        self.nth_process = nth_process

        init_positions = self._init_search()

        try:
            # ``mem`` stays None unless memory is enabled, so the
            # memory_dict lookup below never touches an undefined name.
            mem = None
            if memory is True:
                mem = Memory(memory_warm_start, self.conv)
                self.score = self.results_mang.score(
                    mem.memory(objective_function)
                )
            else:
                self.score = self.results_mang.score(objective_function)

            # loop to initialize N positions
            for init_pos, nth_iter in zip(init_positions, range(n_iter)):
                if self._early_stop():
                    break
                self._initialization(init_pos, nth_iter)

            # loop to do the iterations
            for nth_iter in range(len(init_positions), n_iter):
                if self._early_stop():
                    break
                self._iteration(nth_iter)

            self.results = pd.DataFrame(self.results_mang.results_list)

            self.best_score = self.p_bar.score_best
            self.best_value = self.conv.position2value(self.p_bar.pos_best)
            self.best_para = self.conv.value2para(self.best_value)

            self.results["eval_time"] = self.eval_times
            self.results["iter_time"] = self.iter_times

            self.memory_dict = mem.memory_dict if mem is not None else {}
        finally:
            # close the progress bar even if the objective function raises
            self.p_bar.close()

        self.print_info(
            verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            self.eval_times,
            self.iter_times,
            self.n_iter,
        )
195
196