# Author: Simon Blanke
# Email: [email protected]
# License: MIT License

import os |
|
6
|
|
|
from hyperactive_long_term_memory import LongTermMemory as _LongTermMemory_ |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
class LongTermMemory:
    """Adapter around ``hyperactive_long_term_memory.LongTermMemory``.

    Holds the study/model identifiers, owns a backend instance, and selects
    the per-iteration hook ``ltm_obj_func_wrapper`` according to ``save_on``.

    Parameters
    ----------
    study_id
        Identifier of the study; forwarded to the backend's ``init_study``.
    model_id
        Identifier of the model; forwarded to the backend's ``init_study``.
    save_on : str, default "finish"
        ``"finish"``: ``ltm_obj_func_wrapper`` is a no-op, so nothing is saved
        per iteration (presumably the caller uses ``save_on_finish`` instead).
        ``"iteration"``: every call to ``ltm_obj_func_wrapper`` forwards the
        evaluated parameters and score to ``save_on_iteration``.

    Raises
    ------
    ValueError
        If ``save_on`` is neither ``"finish"`` nor ``"iteration"``.
    """

    def __init__(
        self,
        study_id,
        model_id,
        save_on="finish",
    ):
        # Resolve (and validate) the hook before creating the backend, so an
        # invalid value fails fast here instead of surfacing later as a
        # missing ``ltm_obj_func_wrapper`` attribute.
        if save_on == "finish":
            wrapper = self._no_ltm_wrapper
        elif save_on == "iteration":
            wrapper = self._ltm_wrapper
        else:
            raise ValueError(
                "save_on must be 'finish' or 'iteration', got {!r}".format(save_on)
            )

        self.study_id = study_id
        self.model_id = model_id

        # NOTE(review): the backend is constructed with path="." (presumably
        # the current working directory as storage root) — confirm intended.
        self.ltm_origin = _LongTermMemory_(path=".")
        self.save_on = save_on
        self.ltm_obj_func_wrapper = wrapper

    def _no_ltm_wrapper(self, results, para):
        """Per-iteration hook for ``save_on="finish"``: intentionally does nothing."""

    def _ltm_wrapper(self, results, para):
        """Per-iteration hook for ``save_on="iteration"``: save one record.

        Parameters
        ----------
        results
            Either a bare score, or a tuple whose first element is the score
            and whose second element is a mapping of extra values to record.
        para : mapping
            Parameter values evaluated in this iteration.

        The record is ``{**para, **extra, "score": score}`` and is passed to
        ``save_on_iteration`` together with ``self.nth_process``, which is set
        by ``init_study`` — so ``init_study`` must have been called first.
        """
        if isinstance(results, tuple):
            score = results[0]
            # Copy so adding "score" below does not mutate the caller's dict.
            results_dict = dict(results[1])
        else:
            score = results
            results_dict = {}

        results_dict["score"] = score
        # Later keys win: result values (including "score") override any
        # parameter with the same name.
        ltm_dict = {**para, **results_dict}
        self.save_on_iteration(ltm_dict, self.nth_process)

    def init_study(self, objective_function, search_space, nth_process):
        """Remember ``nth_process`` and initialise the study in the backend.

        Forwards ``objective_function`` and ``search_space`` to the backend's
        ``init_study`` along with this object's ``study_id`` and ``model_id``.
        """
        self.nth_process = nth_process
        self.ltm_origin.init_study(
            objective_function,
            search_space,
            study_id=self.study_id,
            model_id=self.model_id,
        )

    def load(self):
        """Return whatever the backend's ``load()`` returns."""
        return self.ltm_origin.load()

    def save_on_finish(self, dataframe):
        """Delegate saving of the final ``dataframe`` to the backend."""
        self.ltm_origin.save_on_finish(dataframe)

    def save_on_iteration(self, data_dict, nth_process):
        """Delegate saving of a single-iteration record to the backend."""
        self.ltm_origin.save_on_iteration(data_dict, nth_process)