Passed
Push — master ( 566aaf...d7d703 )
by Simon
04:24
created

gradient_free_optimizers.search   A

Complexity

Total Complexity 17

Size/Duplication

Total Lines 189
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 132
dl 0
loc 189
rs 10
c 0
b 0
f 0
wmc 17

8 Methods

Rating   Name   Duplication   Size   Complexity  
A Search.init_search() 0 50 4
A Search.__init__() 0 12 1
A Search.finish_search() 0 23 2
A Search.search() 0 28 3
A Search._score() 0 3 1
A Search.search_step() 0 11 4
A Search._iteration() 0 18 1
A Search._initialization() 0 18 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
7
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
8
from .times_tracker import TimesTracker
9
from .search_statistics import SearchStatistics
10
from .memory import Memory
11
from .print_info import print_info
12
from .stop_run import StopRun
13
14
from .results_manager import ResultsManager
15
16
17
class Search(TimesTracker, SearchStatistics):
    """Drive an optimization run: initialization steps, then iteration steps.

    This class relies on names it does not define itself, which are expected
    to come from the optimizer subclass or the other base classes:
    ``conv``, ``init``, ``nth_process``, ``init_pos()``, ``iterate()``,
    ``evaluate_init()``, ``evaluate()``, ``finish_initialization()``,
    ``eval_times``, ``iter_times`` and the counters ``n_init_total``,
    ``n_init_search``, ``n_iter_total``, ``n_iter_search``.
    NOTE(review): the counters/timing lists are presumably provided by
    SearchStatistics / TimesTracker — confirm.
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

        # Every evaluated position and its score, in evaluation order.
        self.score_l = []
        self.pos_l = []
        self.random_seed = None

        self.results_mang = ResultsManager()

    @TimesTracker.eval_time
    def _score(self, pos):
        """Evaluate ``pos`` with the score function built in ``init_search``."""
        return self.score(pos)

    def _score_and_record(self, pos, evaluate):
        """Score ``pos``, hand the score to ``evaluate``, then record it.

        Shared by ``_initialization`` and ``_iteration``: appends the
        position/score to the history lists and updates the progress bar.
        The order (evaluate before recording) matches the original code.
        """
        score_new = self._score(pos)
        evaluate(score_new)

        self.pos_l.append(pos)
        self.score_l.append(score_new)

        self.p_bar.update(score_new, pos, self.nth_iter)

    @TimesTracker.iter_time
    def _initialization(self):
        """Run one initialization step using the position from ``init_pos``."""
        self.best_score = self.p_bar.score_best

        init_pos = self.init_pos()
        self._score_and_record(init_pos, self.evaluate_init)

        self.n_init_total += 1
        self.n_init_search += 1

        self.stop.update(self.p_bar.score_best, self.score_l)

    @TimesTracker.iter_time
    def _iteration(self):
        """Run one optimization step using the position from ``iterate``."""
        self.best_score = self.p_bar.score_best

        pos_new = self.iterate()
        self._score_and_record(pos_new, self.evaluate)

        self.n_iter_total += 1
        self.n_iter_search += 1

        self.stop.update(self.p_bar.score_best, self.score_l)

    def search(
        self,
        objective_function,
        n_iter,
        max_time=None,
        max_score=None,
        early_stopping=None,
        memory=True,
        memory_warm_start=None,
        verbosity=None,
    ):
        """Run up to ``n_iter`` search steps, then finalize the results.

        Args:
            objective_function: callable whose score is optimized.
            n_iter: maximum number of search steps.
            max_time, max_score, early_stopping: stop criteria forwarded to
                ``StopRun``; the loop ends early once ``stop.check()`` is true.
            memory: memoize objective evaluations unless False/None.
            memory_warm_start: forwarded to ``Memory``.
            verbosity: list of output flags; ``None`` (default) means
                ``["progress_bar", "print_results", "print_times"]``,
                ``False`` disables all output.
        """
        # A fresh list per call: a list literal as the default value would be
        # one object shared by every call, and it is stored on ``self``.
        if verbosity is None:
            verbosity = ["progress_bar", "print_results", "print_times"]

        self.init_search(
            objective_function,
            n_iter,
            max_time,
            max_score,
            early_stopping,
            memory,
            memory_warm_start,
            verbosity,
        )

        for nth_trial in range(n_iter):
            self.search_step(nth_trial)
            if self.stop.check():
                break

        self.finish_search()

    @SearchStatistics.init_stats
    def init_search(
        self,
        objective_function,
        n_iter,
        max_time,
        max_score,
        early_stopping,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Store the run settings and build stopper, progress bar and scorer.

        NOTE(review): ``SearchStatistics.init_stats`` presumably resets the
        per-search counters (e.g. ``n_init_search``) — confirm.
        """
        self.objective_function = objective_function
        self.n_iter = n_iter
        self.max_time = max_time
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.verbosity = verbosity

        self.results_mang.conv = self.conv

        if self.verbosity is False:
            self.verbosity = []

        start_time = time.time()
        self.stop = StopRun(
            start_time, self.max_time, self.max_score, self.early_stopping
        )

        if "progress_bar" in self.verbosity:
            self.p_bar = ProgressBarLVL1(
                self.nth_process, self.n_iter, self.objective_function
            )
        else:
            self.p_bar = ProgressBarLVL0(
                self.nth_process, self.n_iter, self.objective_function
            )

        self.mem = Memory(self.memory_warm_start, self.conv, memory=self.memory)

        # ``in`` compares with ==, so memory=0 also disables memoization.
        if self.memory not in [False, None]:
            self.score = self.results_mang.score(
                self.mem.memory(self.objective_function)
            )
        else:
            self.score = self.results_mang.score(self.objective_function)

        # Init steps for this run: the inits not yet done across earlier
        # searches, capped at the number of iterations of this run.
        self.n_inits_norm = min((self.init.n_inits - self.n_init_total), self.n_iter)

    def finish_search(self):
        """Collect best results and memory, close the bar, print the summary."""
        self.search_data = self.results_mang.search_data

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        if self.memory not in [False, None]:
            self.memory_dict = self.mem.memory_dict
        else:
            self.memory_dict = {}

        self.p_bar.close()

        print_info(
            self.verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            self.eval_times,
            self.iter_times,
            self.n_iter,
            self.random_seed,
        )

    def search_step(self, nth_iter):
        """Execute step ``nth_iter`` of the run.

        Steps below ``n_inits_norm`` are initialization steps. On the step
        whose index equals the number of init steps done in this search,
        ``finish_initialization`` runs once, and that same step (and later
        ones, up to ``n_iter``) is an iteration step.
        """
        self.nth_iter = nth_iter

        if self.nth_iter < self.n_inits_norm:
            self._initialization()

        if self.nth_iter == self.n_init_search:
            self.finish_initialization()

        if self.n_init_search <= self.nth_iter < self.n_iter:
            self._iteration()