Passed
Push — master ( e37ecf...2614a9 )
by Simon
04:19
created

Search.search_step()   A

Complexity

Conditions 4

Size

Total Lines 13
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 8
nop 2
dl 0
loc 13
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
7
from multiprocessing.managers import DictProxy
8
9
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
10
from .times_tracker import TimesTracker
11
from .search_statistics import SearchStatistics
12
from .memory import Memory
13
from .print_info import print_info
14
from .stop_run import StopRun
15
16
from .results_manager import ResultsManager
17
18
19
class Search(TimesTracker, SearchStatistics):
    """Search-loop mixin shared by the optimizers.

    ``search`` runs up to ``n_iter`` steps. The first ``n_inits_norm`` steps
    evaluate initialization positions (``init_pos``); later steps evaluate
    positions proposed by ``iterate``. After every step a ``StopRun`` object
    decides whether the loop breaks early.

    Not defined in this class, so expected from a subclass or another base:
    ``init_pos``, ``iterate``, ``evaluate``, ``evaluate_init``,
    ``finish_initialization`` and the attributes ``conv``, ``init`` and
    ``nth_process``. The ``n_init_*`` / ``n_iter_*`` counters are presumably
    reset by ``SearchStatistics.init_stats`` — TODO confirm.
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

        # Every evaluated position and its score, in evaluation order.
        self.score_l = []
        self.pos_l = []
        self.random_seed = None

        self.search_state = "init"

        self.results_mang = ResultsManager()

    @TimesTracker.eval_time
    def _score(self, pos):
        """Evaluate ``pos`` with the score function built by ``init_search``."""
        return self.score(pos)

    @TimesTracker.iter_time
    def _initialization(self):
        """Evaluate one initialization position and record the result."""
        self.best_score = self.p_bar.score_best

        init_pos = self.init_pos()

        score_new = self._score(init_pos)
        self.evaluate_init(score_new)

        self.pos_l.append(init_pos)
        self.score_l.append(score_new)

        self.p_bar.update(score_new, init_pos, self.nth_iter)

        self.n_init_total += 1
        self.n_init_search += 1

        self.stop.update(self.p_bar.score_best, self.score_l)

    @TimesTracker.iter_time
    def _iteration(self):
        """Evaluate one position proposed by ``iterate`` and record the result."""
        self.best_score = self.p_bar.score_best

        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.pos_l.append(pos_new)
        self.score_l.append(score_new)

        self.p_bar.update(score_new, pos_new, self.nth_iter)

        self.n_iter_total += 1
        self.n_iter_search += 1

        self.stop.update(self.p_bar.score_best, self.score_l)

    def search(
        self,
        objective_function,
        n_iter,
        max_time=None,
        max_score=None,
        early_stopping=None,
        memory=True,
        memory_warm_start=None,
        verbosity=None,
    ):
        """Run a complete search of at most ``n_iter`` steps.

        Parameters
        ----------
        objective_function : callable
            Function to optimize; wrapped by the results manager (and, when
            memory is enabled, by ``Memory``) before evaluation.
        n_iter : int
            Maximum number of steps (initialization + iteration).
        max_time, max_score, early_stopping
            Stop criteria forwarded to ``StopRun``.
        memory : bool or multiprocessing DictProxy
            ``True`` wraps the objective in a local ``Memory``; a
            ``DictProxy`` wraps it in a ``Memory`` backed by that proxy;
            any other value disables the wrapper.
        memory_warm_start
            Forwarded to ``Memory``.
        verbosity : list of str, False or None
            Enabled output channels. ``False`` disables all output; ``None``
            (the default) means
            ``["progress_bar", "print_results", "print_times"]``.
        """
        if verbosity is None:
            # Build a fresh list per call instead of sharing one mutable
            # default list object across every call / instance.
            verbosity = ["progress_bar", "print_results", "print_times"]

        self.init_search(
            objective_function,
            n_iter,
            max_time,
            max_score,
            early_stopping,
            memory,
            memory_warm_start,
            verbosity,
        )

        for nth_iter in range(n_iter):
            self.search_step(nth_iter)
            if self.stop.check():
                break

        self.finish_search()

    def _uses_memory(self):
        """Return True when ``self.memory`` selects a ``Memory`` wrapper.

        Used by both ``init_search`` (which creates ``self.mem``) and
        ``finish_search`` (which reads it) so the two can never disagree.
        """
        return isinstance(self.memory, DictProxy) or self.memory is True

    @SearchStatistics.init_stats
    def init_search(
        self,
        objective_function,
        n_iter,
        max_time,
        max_score,
        early_stopping,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Store the search settings and build the per-search helpers.

        Creates the ``StopRun`` object, the progress bar, the (optionally
        memory-wrapped) score function ``self.score`` and ``n_inits_norm``,
        the number of initialization steps this search will run.
        """
        self.objective_function = objective_function
        self.n_iter = n_iter
        self.max_time = max_time
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.verbosity = verbosity

        self.results_mang.conv = self.conv

        if self.verbosity is False:
            self.verbosity = []

        start_time = time.time()
        self.stop = StopRun(
            start_time, self.max_time, self.max_score, self.early_stopping
        )

        p_bar_class = (
            ProgressBarLVL1 if "progress_bar" in self.verbosity else ProgressBarLVL0
        )
        self.p_bar = p_bar_class(
            self.nth_process, self.n_iter, self.objective_function
        )

        if self._uses_memory():
            if isinstance(self.memory, DictProxy):
                # Back the memory with the shared proxy dict.
                self.mem = Memory(
                    self.memory_warm_start, self.conv, dict_proxy=self.memory
                )
            else:
                self.mem = Memory(self.memory_warm_start, self.conv)
            score_function = self.mem.memory(self.objective_function)
        else:
            score_function = self.objective_function
        self.score = self.results_mang.score(score_function)

        # Initialization steps already done in earlier searches
        # (n_init_total) are subtracted, and never more than n_iter are run.
        self.n_inits_norm = min((self.init.n_inits - self.n_init_total), self.n_iter)

    def finish_search(self):
        """Collect the best result, close the progress bar and print a summary."""
        self.search_data = self.results_mang.search_data

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        # Read self.mem only when init_search created it for THIS search.
        # Previously e.g. memory=1 reached here with no (or a stale) self.mem.
        self.memory_dict = self.mem.memory_dict if self._uses_memory() else {}

        self.p_bar.close()

        print_info(
            self.verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            self.eval_times,
            self.iter_times,
            self.n_iter,
            self.random_seed,
        )

    def search_step(self, nth_iter):
        """Run step ``nth_iter`` of the search.

        Steps below ``n_inits_norm`` evaluate initialization positions. When
        ``nth_iter`` equals the number of initializations done in this search,
        ``finish_initialization`` is called; from that step on (while
        ``nth_iter < n_iter``) positions come from ``iterate``.
        """
        self.nth_iter = nth_iter

        if self.nth_iter < self.n_inits_norm:
            self._initialization()

        # NOTE(review): if every step is an initialization step
        # (n_inits_norm == n_iter), this looks unreachable, so
        # finish_initialization is never called — confirm that is intended.
        if self.nth_iter == self.n_init_search:
            self.finish_initialization()

        if self.n_init_search <= self.nth_iter < self.n_iter:
            self._iteration()
213