1
|
|
|
# Author: Simon Blanke |
2
|
|
|
# Email: [email protected] |
3
|
|
|
# License: MIT License |
4
|
|
|
|
5
|
|
|
import os |
6
|
|
|
from hyperactive_long_term_memory import LongTermMemory as _LongTermMemory_ |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
class LongTermMemory:
    """Thin adapter around ``hyperactive_long_term_memory.LongTermMemory``.

    Every public method delegates to the backend instance stored in
    ``self.ltm_origin``. The only logic added here is the selection of
    ``self.ltm_obj_func_wrapper``, a callback with signature
    ``(results, para)`` that either does nothing (``save_on="finish"``) or
    forwards every evaluation to :meth:`save_on_iteration`
    (``save_on="iteration"``).

    Parameters
    ----------
    model_id
        Forwarded unchanged to the backend.
    experiment_id : str, default "default"
        Forwarded unchanged to the backend.
    save_on : {"finish", "iteration"}, default "finish"
        Selects the per-evaluation callback. With ``"finish"`` the callback
        is a no-op (presumably the caller invokes :meth:`save_on_finish`
        itself — confirm against callers); with ``"iteration"`` each
        callback invocation is persisted via :meth:`save_on_iteration`.

    Raises
    ------
    ValueError
        If ``save_on`` is not ``"finish"`` or ``"iteration"``.
    """

    def __init__(
        self,
        model_id,
        experiment_id="default",
        save_on="finish",
    ):
        # Pick the callback before touching the backend so an unsupported
        # value fails immediately with a clear message, instead of leaving
        # ``ltm_obj_func_wrapper`` unset and failing later with an
        # AttributeError.
        if save_on == "finish":
            wrapper = self._no_ltm_wrapper
        elif save_on == "iteration":
            wrapper = self._ltm_wrapper
        else:
            raise ValueError(
                "save_on must be 'finish' or 'iteration', got {!r}".format(save_on)
            )

        # NOTE(review): the backend always receives path="."; presumably it
        # resolves this against the current working directory — confirm.
        self.ltm_origin = _LongTermMemory_(
            model_id=model_id, experiment_id=experiment_id, path="."
        )
        self.save_on = save_on
        self.ltm_obj_func_wrapper = wrapper

    def _no_ltm_wrapper(self, results, para):
        """Per-evaluation callback used when ``save_on == "finish"``; does nothing."""
        pass

    def _ltm_wrapper(self, results, para):
        """Merge one evaluation into a flat record and persist it.

        Parameters
        ----------
        results
            Either a bare score, or a tuple whose first element is the score
            and whose second element is a dict of extra result values.
        para : dict
            Parameter values of the evaluated point.

        The saved record is ``{**para, **extra, "score": score}``: extra
        result keys override parameter keys, and ``"score"`` overrides both.
        Neither input dict is modified.
        """
        if isinstance(results, tuple):
            score, extra = results[0], results[1]
        else:
            score, extra = results, {}

        # Build a new dict rather than inserting "score" into the caller's
        # results dict, which the previous implementation mutated in place.
        ltm_dict = {**para, **extra, "score": score}
        self.save_on_iteration(ltm_dict)

    def clean_files(self):
        """Delegate to the backend's ``clean_files``."""
        self.ltm_origin.clean_files()

    def init_data_types(self, search_space):
        """Delegate to the backend's ``init_data_types`` with ``search_space``."""
        self.ltm_origin.init_data_types(search_space)

    def load(self):
        """Return whatever the backend's ``load`` returns."""
        return self.ltm_origin.load()

    def save_on_finish(self, dataframe):
        """Delegate to the backend's ``save_on_finish`` with ``dataframe``."""
        self.ltm_origin.save_on_finish(dataframe)

    def save_on_iteration(self, data_dict):
        """Delegate to the backend's ``save_on_iteration`` with ``data_dict``."""
        self.ltm_origin.save_on_iteration(data_dict)
59
|
|
|
|