1
|
|
|
# Author: Simon Blanke |
2
|
|
|
# Email: [email protected] |
3
|
|
|
# License: MIT License |
4
|
|
|
|
5
|
|
|
import os |
6
|
|
|
from hyperactive_long_term_memory import LongTermMemory as _LongTermMemory_ |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
class LongTermMemory:
    """Adapter binding a study/model pair to ``hyperactive_long_term_memory``.

    Wraps ``hyperactive_long_term_memory.LongTermMemory`` (imported in this
    module as ``_LongTermMemory_``) and exposes ``ltm_obj_func_wrapper``: a
    callable taking ``(results, para)`` that is a no-op when
    ``save_on="finish"`` and forwards every evaluation to the backend when
    ``save_on="iteration"``.

    Parameters
    ----------
    study_id
        Study identifier, forwarded to the backend by ``init_study``.
    model_id
        Model identifier, forwarded to the backend by ``init_study``.
    save_on : {"finish", "iteration"}, optional
        Selects which per-evaluation hook ``ltm_obj_func_wrapper`` is.
        Defaults to ``"finish"``.

    Raises
    ------
    ValueError
        If ``save_on`` is neither ``"finish"`` nor ``"iteration"``.
    """

    def __init__(
        self,
        study_id,
        model_id,
        save_on="finish",
    ):
        # Validate before creating the backend so an unsupported value fails
        # here, instead of leaving ``ltm_obj_func_wrapper`` unset and
        # surfacing later as an AttributeError.
        if save_on == "finish":
            wrapper = self._no_ltm_wrapper
        elif save_on == "iteration":
            wrapper = self._ltm_wrapper
        else:
            raise ValueError(
                "save_on must be 'finish' or 'iteration', got {!r}".format(save_on)
            )

        self.study_id = study_id
        self.model_id = model_id

        # NOTE(review): the backend is given path="." (presumably resolved
        # against the current working directory -- confirm the backend's
        # path semantics).
        self.ltm_origin = _LongTermMemory_(path=".")
        self.save_on = save_on
        self.ltm_obj_func_wrapper = wrapper

    def _no_ltm_wrapper(self, results, para):
        """Per-evaluation hook for ``save_on="finish"``: does nothing."""
        pass

    def _ltm_wrapper(self, results, para):
        """Forward a single evaluation to the backend.

        Parameters
        ----------
        results
            Either a bare score, or a tuple whose first element is the score
            and whose second element is a dict of additional results.
        para : mapping
            Parameter values of the evaluation.

        The record passed to ``save_on_iteration`` merges ``para`` with the
        additional results and the score. On key collisions the results
        override ``para``, and ``"score"`` always holds the score. Requires
        ``init_study`` to have been called, because it reads
        ``self.nth_process``.
        """
        if isinstance(results, tuple):
            score = results[0]
            # Copy so adding "score" does not mutate the caller's dict.
            results_dict = dict(results[1])
        else:
            score = results
            results_dict = {}

        results_dict["score"] = score
        ltm_dict = {**para, **results_dict}
        self.save_on_iteration(ltm_dict, self.nth_process)

    def init_study(self, objective_function, search_space, nth_process):
        """Remember ``nth_process`` and initialise the study in the backend.

        ``objective_function`` and ``search_space`` are forwarded unchanged,
        together with this instance's ``study_id`` and ``model_id``.
        """
        self.nth_process = nth_process
        self.ltm_origin.init_study(
            objective_function,
            search_space,
            study_id=self.study_id,
            model_id=self.model_id,
        )

    def load(self):
        """Return whatever the backend's ``load()`` returns."""
        return self.ltm_origin.load()

    def save_on_finish(self, dataframe):
        """Forward ``dataframe`` unchanged to the backend's ``save_on_finish``."""
        self.ltm_origin.save_on_finish(dataframe)

    def save_on_iteration(self, data_dict, nth_process):
        """Forward one record and its process index to the backend."""
        self.ltm_origin.save_on_iteration(data_dict, nth_process)