Passed
Push — master ( f8108f...3d944f )
by Simon
01:32
created

gradient_free_optimizers.search   A

Complexity

Total Complexity 21

Size/Duplication

Total Lines 208
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 21
eloc 147
dl 0
loc 208
rs 10
c 0
b 0
f 0

9 Methods

Rating   Name   Duplication   Size   Complexity  
B Search.init_search() 0 53 6
A Search.__init__() 0 12 1
A Search.finish_search() 0 23 1
B Search.search() 0 41 5
A Search._score() 0 3 1
A Search._evaluate_position() 0 5 1
A Search.search_step() 0 11 4
A Search._iteration() 0 16 1
A Search._initialization() 0 16 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
7
from ._progress_bar import ProgressBarLVL0, ProgressBarLVL1
8
from ._times_tracker import TimesTracker
9
from ._search_statistics import SearchStatistics
10
from ._print_info import print_info
11
from ._stop_run import StopRun
12
from ._results_manager import ResultsManager
13
from ._objective_adapter import ObjectiveAdapter
14
from ._memory import CachedObjectiveAdapter
15
from ._stopping_conditions import OptimizationStopper
16
17
18
class Search(TimesTracker, SearchStatistics):
    """Drive an optimizer through its initialization and iteration steps.

    This class relies on members it does not define itself: ``init_pos``,
    ``iterate``, ``evaluate_init``, ``evaluate``, ``finish_initialization``,
    ``score``, ``conv``, ``init``, ``nth_process`` and the ``n_init_*`` /
    ``n_iter_*`` counters. NOTE(review): these are presumably supplied by the
    concrete optimizer class and the mixins (e.g. ``SearchStatistics``);
    confirm against the subclasses.
    """

    def __init__(self):
        super().__init__()

        self.optimizers = []
        self.new_results_list = []
        self.all_results_list = []

        # Evaluated positions and their scores, in evaluation order; appended
        # to by ``_initialization`` and ``_iteration`` and never reset here.
        self.score_l = []
        self.pos_l = []
        self.random_seed = None

        # Collects every (result, params) pair; exposed as ``search_data``
        # by ``finish_search``.
        self.results_manager = ResultsManager()

    @TimesTracker.eval_time
    def _score(self, pos):
        """Return ``self.score(pos)``, timed by ``TimesTracker.eval_time``."""
        return self.score(pos)

    @TimesTracker.iter_time
    def _initialization(self):
        """Evaluate the next initial position and record the result."""
        self.best_score = self.p_bar.score_best

        init_pos = self.init_pos()

        score_new = self._evaluate_position(init_pos)
        self.evaluate_init(score_new)

        self.pos_l.append(init_pos)
        self.score_l.append(score_new)

        self.p_bar.update(score_new, init_pos, self.nth_iter)

        self.n_init_total += 1
        self.n_init_search += 1

    @TimesTracker.iter_time
    def _iteration(self):
        """Ask the optimizer for a new position, evaluate it and record it."""
        self.best_score = self.p_bar.score_best

        pos_new = self.iterate()

        score_new = self._evaluate_position(pos_new)
        self.evaluate(score_new)

        self.pos_l.append(pos_new)
        self.score_l.append(score_new)

        self.p_bar.update(score_new, pos_new, self.nth_iter)

        self.n_iter_total += 1
        self.n_iter_search += 1

    def search(
        self,
        objective_function,
        n_iter,
        max_time=None,
        max_score=None,
        early_stopping=None,
        memory=True,
        memory_warm_start=None,
        verbosity=None,
        optimum="maximum",
    ):
        """Run a complete optimization of ``objective_function``.

        Parameters
        ----------
        objective_function : callable
            Objective passed to the objective adapter; its score is maximized
            (or minimized when ``optimum == "minimum"``).
        n_iter : int
            Maximum number of search steps.
        max_time, max_score, early_stopping
            Stopping criteria forwarded to ``OptimizationStopper``.
        memory, memory_warm_start
            If ``memory`` is not ``False``/``None`` a ``CachedObjectiveAdapter``
            is used and seeded with ``memory_warm_start``.
        verbosity : list of str, False or None
            Output channels. ``None`` (default) selects
            ``["progress_bar", "print_results", "print_times"]``; ``False``
            disables all output. Including ``"debug_stop"`` prints the
            stopper's debug info when a stopping condition fires.
        optimum : {"maximum", "minimum"}
            Direction of optimization.
        """
        if verbosity is None:
            # Build a fresh list per call: a list literal as the default value
            # would be one shared object stored on every instance's
            # ``self.verbosity``.
            verbosity = ["progress_bar", "print_results", "print_times"]

        self.optimum = optimum
        self.init_search(
            objective_function,
            n_iter,
            max_time,
            max_score,
            early_stopping,
            memory,
            memory_warm_start,
            verbosity,
        )

        for nth_trial in range(n_iter):
            self.search_step(nth_trial)

            # Feed the stopper the latest and best scores. ``float("-inf")``
            # replaces the former ``-np.inf``: numpy was never imported here.
            current_score = self.score_l[-1] if self.score_l else float("-inf")
            best_score = self.p_bar.score_best
            self.stopper.update(current_score, best_score, nth_trial)

            if self.stopper.should_stop():
                if "debug_stop" in self.verbosity:
                    # Imported locally: only needed on this debug-only path.
                    import json

                    debug_info = self.stopper.get_debug_info()
                    print("\nStopping condition debug info:")
                    print(json.dumps(debug_info, indent=2))
                break

        self.finish_search()

    def _evaluate_position(self, pos: list[int]) -> float:
        """Evaluate ``pos`` through the adapter, store the result, return its score."""
        result, params = self.adapter(pos)
        self.results_manager.add(result, params)
        self._iter += 1
        return result.score

    @staticmethod
    def _negate_objective(objective_function):
        """Return a wrapper that negates the score of ``objective_function``.

        Maximizing the wrapper minimizes the original. A tuple return value is
        treated as ``(score, *extras)`` and only its first element is negated.
        NOTE(review): this assumes the objective adapter accepts that tuple
        shape; confirm in ``_objective_adapter``.
        """
        import functools

        # ``wraps`` keeps the original ``__name__``/``__doc__``; the previous
        # lambda reported itself as "<lambda>".
        @functools.wraps(objective_function)
        def negated(params):
            out = objective_function(params)
            if isinstance(out, tuple):
                return (-out[0], *out[1:])
            return -out

        return negated

    @SearchStatistics.init_stats
    def init_search(
        self,
        objective_function,
        n_iter,
        max_time,
        max_score,
        early_stopping,
        memory,
        memory_warm_start,
        verbosity,
    ):
        """Prepare all per-run state before the first ``search_step``.

        This stores the run settings and builds the stopper, the progress bar
        and the objective adapter. When ``self.optimum == "minimum"`` the
        objective is negated.
        """
        if getattr(self, "optimum", "maximum") == "minimum":
            self.objective_function = self._negate_objective(objective_function)
        else:
            self.objective_function = objective_function
        self.n_iter = n_iter
        self.max_time = max_time
        self.max_score = max_score
        self.early_stopping = early_stopping
        self.memory = memory
        self.memory_warm_start = memory_warm_start
        self.verbosity = verbosity

        self._iter = 0

        if self.verbosity is False:
            self.verbosity = []

        start_time = time.time()
        self.stopper = OptimizationStopper(
            start_time=start_time,
            max_time=max_time,
            max_score=max_score,
            early_stopping=early_stopping,
        )

        if "progress_bar" in self.verbosity:
            self.p_bar = ProgressBarLVL1(
                self.nth_process, self.n_iter, self.objective_function
            )
        else:
            self.p_bar = ProgressBarLVL0(
                self.nth_process, self.n_iter, self.objective_function
            )

        # The adapter produces every score the optimizer sees, so it must get
        # the (possibly negated) ``self.objective_function``. Passing the raw
        # function made ``optimum="minimum"`` have no effect on the search.
        # NOTE(review): under "minimum" the scores stored in ``search_data``
        # (and in the memory cache) are therefore presumably the negated ones.
        if self.memory not in [False, None]:
            self.adapter = CachedObjectiveAdapter(self.conv, self.objective_function)
            self.adapter.memory(memory_warm_start, memory)
        else:
            self.adapter = ObjectiveAdapter(self.conv, self.objective_function)

        # Initialization steps still to run in this search, capped at n_iter.
        self.n_inits_norm = min((self.init.n_inits - self.n_init_total), self.n_iter)

    def finish_search(self):
        """Collect the results of the run, close the progress bar and print info."""
        self.search_data = self.results_manager.dataframe

        self.best_score = self.p_bar.score_best
        self.best_value = self.conv.position2value(self.p_bar.pos_best)
        self.best_para = self.conv.value2para(self.best_value)

        self.p_bar.close()

        print_info(
            self.verbosity,
            self.objective_function,
            self.best_score,
            self.best_para,
            self.eval_times,
            self.iter_times,
            self.n_iter,
            self.random_seed,
        )

    def search_step(self, nth_iter):
        """Perform step ``nth_iter`` of the search.

        The first ``n_inits_norm`` steps are initialization steps. The first
        step whose index equals the number of completed initialization steps
        calls ``finish_initialization`` and then, like every later step below
        ``n_iter``, runs a regular iteration.
        """
        self.nth_iter = nth_iter

        if self.nth_iter < self.n_inits_norm:
            self._initialization()

        if self.nth_iter == self.n_init_search:
            self.finish_initialization()

        if self.n_init_search <= self.nth_iter < self.n_iter:
            self._iteration()
208