gradient_free_optimizers.search   A
last analyzed

Complexity

Total Complexity 19

Size/Duplication

Total Lines 195
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 19
eloc 137
dl 0
loc 195
rs 10
c 0
b 0
f 0

8 Methods

Rating   Name   Duplication   Size   Complexity  
A Search.__init__() 0 12 1
A Search._score() 0 3 1
A Search._iteration() 0 18 1
A Search._initialization() 0 18 1
B Search.init_search() 0 54 6
A Search.finish_search() 0 23 2
A Search.search() 0 30 3
A Search.search_step() 0 11 4
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
7
from ._progress_bar import ProgressBarLVL0, ProgressBarLVL1
8
from ._times_tracker import TimesTracker
9
from ._search_statistics import SearchStatistics
10
from ._memory import Memory
11
from ._print_info import print_info
12
from ._stop_run import StopRun
13
from ._results_manager import ResultsManager
14
15
16
class Search(TimesTracker, SearchStatistics):
    """Mixin that drives the optimization loop of a concrete optimizer.

    The class mixing this in must provide ``init_pos``, ``iterate``,
    ``evaluate_init``, ``evaluate`` and ``finish_initialization``, and the
    attributes ``conv``, ``init`` and ``nth_process``. The step counters
    (``n_init_total``, ``n_init_search``, ``n_iter_total``, ``n_iter_search``)
    and ``eval_times`` / ``iter_times`` are used here but created elsewhere
    (presumably by ``SearchStatistics`` / ``TimesTracker`` -- confirm).
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

        # Every evaluated position and its score, in evaluation order.
        self.score_l = []
        self.pos_l = []
        self.random_seed = None

        self.results_mang = ResultsManager()

    @staticmethod
    def _negate_objective(objective_function):
        """Return a wrapper around ``objective_function`` that negates its score.

        Used for ``optimum="minimum"`` so the rest of the search keeps
        maximising. ``functools.wraps`` keeps ``__name__``/``__doc__`` of the
        user's function. Anything that inspects the objective handed to the
        progress bar or ``print_info`` (e.g. its name) therefore sees the
        real function rather than ``<lambda>``.
        """
        from functools import wraps

        @wraps(objective_function)
        def negated_objective(*args, **kwargs):
            result = objective_function(*args, **kwargs)
            if isinstance(result, tuple):
                # NOTE(review): assumes a tuple return is ``(score, *extras)``;
                # only the score is negated. A plain ``-result`` raised
                # TypeError for tuple returns.
                return (-result[0],) + result[1:]
            return -result

        return negated_objective

    @TimesTracker.eval_time
    def _score(self, pos):
        """Score ``pos`` with the objective wrapper built in ``init_search``."""
        return self.score(pos)

    @TimesTracker.iter_time
    def _initialization(self):
        """Run one initialization step: score an initial position and record it."""
        # Expose the best score so far on the optimizer before it chooses the
        # next position (presumably read by ``init_pos`` -- confirm).
        self.best_score = self.p_bar.score_best

        init_pos = self.init_pos()

        score_new = self._score(init_pos)
        self.evaluate_init(score_new)

        self.pos_l.append(init_pos)
        self.score_l.append(score_new)

        self.p_bar.update(score_new, init_pos, self.nth_iter)

        self.n_init_total += 1
        self.n_init_search += 1

        self.stop.update(self.p_bar.score_best, self.score_l)

    @TimesTracker.iter_time
    def _iteration(self):
        """Run one iteration step: let the optimizer propose a position and score it."""
        self.best_score = self.p_bar.score_best

        pos_new = self.iterate()

        score_new = self._score(pos_new)
        self.evaluate(score_new)

        self.pos_l.append(pos_new)
        self.score_l.append(score_new)

        self.p_bar.update(score_new, pos_new, self.nth_iter)

        self.n_iter_total += 1
        self.n_iter_search += 1

        self.stop.update(self.p_bar.score_best, self.score_l)

    def search(
        self,
        objective_function,
        n_iter,
        max_time=None,
        max_score=None,
        early_stopping=None,
        memory=True,
        memory_warm_start=None,
        verbosity=None,
        optimum="maximum",
    ):
        """Run a complete search of at most ``n_iter`` steps.

        Parameters
        ----------
        objective_function : callable
            Function to optimise. It is wrapped by ``ResultsManager.score``
            (and by ``Memory.memory`` when memory is enabled) before use.
        n_iter : int
            Maximum number of steps (initialization plus iteration steps).
        max_time, max_score, early_stopping
            Stop criteria forwarded to ``StopRun``, checked after every step.
        memory
            Memoisation via ``Memory`` is used unless this is ``False``/``None``.
        memory_warm_start
            Forwarded to ``Memory``.
        verbosity : list of str, False or None
            Output flags. ``"progress_bar"`` selects ``ProgressBarLVL1``
            instead of ``ProgressBarLVL0``. The list is also passed to
            ``print_info``. ``None`` (default) means
            ``["progress_bar", "print_results", "print_times"]``; ``False``
            disables all flags.
        optimum : str
            ``"maximum"`` (default) or ``"minimum"``. ``"minimum"`` negates
            the objective, so the reported scores are negated too.
            NOTE(review): any other value is silently treated as ``"maximum"``.

        Results are stored on the instance by ``finish_search``.
        """
        if verbosity is None:
            # Fresh list per call; a list literal as the default would be
            # shared (and mutable) across every call.
            verbosity = ["progress_bar", "print_results", "print_times"]

        # Read by ``init_search`` (its signature is kept unchanged because it
        # is wrapped by ``SearchStatistics.init_stats``).
        self.optimum = optimum
        self.init_search(
            objective_function,
            n_iter,
            max_time,
            max_score,
            early_stopping,
            memory,
            memory_warm_start,
            verbosity,
        )

        for nth_trial in range(n_iter):
            self.search_step(nth_trial)
            if self.stop.check():
                break

        self.finish_search()

    @SearchStatistics.init_stats
    def init_search(
        self,
        objective_function,
        n_iter,
        max_time,
        max_score,
        early_stopping,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Prepare the instance for a search run step by step via ``search_step``.

        Stores the run settings and builds the stop criteria, the progress
        bar, the memory and the scoring wrapper (``self.score``). Minimisation
        is selected through ``self.optimum`` (set by ``search``; defaults to
        ``"maximum"`` when absent).
        """
        if getattr(self, "optimum", "maximum") == "minimum":
            self.objective_function = self._negate_objective(objective_function)
        else:
            self.objective_function = objective_function
        self.n_iter = n_iter
        self.max_time = max_time
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.verbosity = verbosity

        self.results_mang.conv = self.conv

        # ``False`` disables all output. ``None`` is treated the same way;
        # it previously crashed on the ``in`` check below.
        if self.verbosity is False or self.verbosity is None:
            self.verbosity = []

        start_time = time.time()
        self.stop = StopRun(
            start_time, self.max_time, self.max_score, self.early_stopping
        )

        if "progress_bar" in self.verbosity:
            self.p_bar = ProgressBarLVL1(
                self.nth_process, self.n_iter, self.objective_function
            )
        else:
            self.p_bar = ProgressBarLVL0(
                self.nth_process, self.n_iter, self.objective_function
            )

        self.mem = Memory(self.memory_warm_start, self.conv, memory=self.memory)

        if self.memory not in [False, None]:
            self.score = self.results_mang.score(
                self.mem.memory(self.objective_function)
            )
        else:
            self.score = self.results_mang.score(self.objective_function)

        # Initialization steps for this run: the ``init.n_inits`` positions
        # not yet evaluated (``n_init_total`` presumably accumulates across
        # runs -- confirm), capped at ``n_iter``.
        self.n_inits_norm = min(
            (self.init.n_inits - self.n_init_total), self.n_iter
        )

    def finish_search(self):
        """Collect the results of the run and print the report.

        Sets ``search_data``, ``best_score``, ``best_value``, ``best_para``
        and ``memory_dict``, closes the progress bar and calls ``print_info``.
        With ``optimum="minimum"`` the objective was negated in
        ``init_search``, so ``best_score`` is the negated objective value.
        """
        self.search_data = self.results_mang.search_data

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        if self.memory not in [False, None]:
            self.memory_dict = self.mem.memory_dict
        else:
            self.memory_dict = {}

        self.p_bar.close()

        print_info(
            self.verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            self.eval_times,
            self.iter_times,
            self.n_iter,
            self.random_seed,
        )

    def search_step(self, nth_iter):
        """Run step ``nth_iter`` of a search prepared by ``init_search``.

        Steps below ``n_inits_norm`` are initialization steps. The first step
        after them calls ``finish_initialization`` and then, like every later
        step below ``n_iter``, runs an iteration step.
        """
        self.nth_iter = nth_iter

        # These are deliberately separate ``if``s, not ``elif``: each check
        # reads ``n_init_search``, which ``_initialization`` increments.
        if self.nth_iter < self.n_inits_norm:
            self._initialization()

        if self.nth_iter == self.n_init_search:
            self.finish_initialization()

        if self.n_init_search <= self.nth_iter < self.n_iter:
            self._iteration()