1
|
|
|
# Author: Simon Blanke |
2
|
|
|
# Email: [email protected] |
3
|
|
|
# License: MIT License |
4
|
|
|
|
5
|
|
|
import os |
6
|
|
|
import dill |
7
|
|
|
import shutil |
8
|
|
|
import pandas as pd |
9
|
|
|
|
10
|
|
|
from .ltm_data_path import ltm_data_path |
11
|
|
|
|
12
|
|
|
|
13
|
|
|
def merge_unhashable_df(df1, df2):
    """Stack the rows of ``df2`` below the rows of ``df1``.

    Both frames get a temporary, non-overlapping integer key and are joined
    with an outer merge on that key only, so cell values are never hashed
    and may be unhashable objects (lists, dicts, ...).

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Frames with the same set of (string) column names; the column order
        may differ. Neither frame is modified.

    Returns
    -------
    pandas.DataFrame or None
        A new frame with the columns of ``df1`` (in that order), the rows of
        ``df1`` first and the rows of ``df2`` after them, on a fresh integer
        index. ``None`` is returned (after printing a message) when the
        column sets differ.
    """
    columns = df1.columns

    if set(columns) != set(df2.columns):
        print("Error columns of df1 and df2 must be the same")
        return None

    # Choose a key name that cannot collide with a real data column.
    key = "item"
    while key in columns:
        key = "_" + key

    # ``assign`` returns copies, so the caller's frames are left untouched.
    n_left = len(df1)
    left = df1.assign(**{key: range(n_left)})
    right = df2.assign(**{key: range(n_left, n_left + len(df2))})

    result = pd.merge(left, right, on=key, how="outer", suffixes=("", "_y"))

    for col in columns:
        incoming = result[col + "_y"]
        # Take the df2 value wherever one exists. Assigning the column back
        # (instead of a chained in-place ``update``) also works when pandas
        # runs with Copy-on-Write.
        result[col] = result[col].where(incoming.isna(), incoming)

    # Select the original columns explicitly: this drops the key and the
    # "_y" helpers without discarding data columns that end in "_y".
    return result[list(columns)]
31
|
|
|
|
32
|
|
|
|
33
|
|
|
def drop_duplicates_unhashable_df(df):
    """Return ``df`` without duplicated rows, keeping first occurrences.

    Rows are compared through their string representation, so the frame may
    hold unhashable cell values (lists, dicts, ...). The original values and
    index labels of the surviving rows are kept.

    Parameters
    ----------
    df : pandas.DataFrame

    Returns
    -------
    pandas.DataFrame
    """
    # A positional boolean mask (rather than a label lookup through ``.loc``)
    # stays correct when the index contains repeated labels.
    is_duplicate = df.astype(str).duplicated().values
    return df.loc[~is_duplicate]
35
|
|
|
|
36
|
|
|
|
37
|
|
|
def meta_data_path():
    """Return the directory that contains this module.

    The path is resolved through ``os.path.realpath`` (symlinks followed)
    and ends with a path separator, so a file name can be appended directly.
    """
    current_path = os.path.realpath(__file__)
    # ``dirname`` understands the platform's separator; splitting on "/"
    # only works for POSIX-style paths. Joining with "" appends exactly one
    # trailing separator.
    return os.path.join(os.path.dirname(current_path), "")
40
|
|
|
|
41
|
|
|
|
42
|
|
|
class LongTermMemory:
    """Store search data and the objective function of a model on disk.

    Data lives in ``<ltm_data_dir>/<model_name>:<study_name>/`` as two dill
    files: ``search_data.pkl`` and ``objective_function.pkl``.

    Parameters
    ----------
    model_name : str
        First part of the storage directory name.
    study_name : str, optional
        Second part of the storage directory name; ``"default"`` when None.
    path : str, optional
        Base directory; data is placed in ``path + "/ltm_data/"``. When None
        the directory returned by ``ltm_data_path()`` is used (presumably
        ending in a separator -- confirm against that helper).
    verbosity : optional
        Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, model_name, study_name=None, path=None, verbosity=None):
        if study_name is None:
            study_name = "default"

        model_study_name = model_name + ":" + study_name

        if path is None:
            self.ltm_data_dir = ltm_data_path()
        else:
            self.ltm_data_dir = path + "/ltm_data/"

        self.model_dir = self.ltm_data_dir + model_study_name + "/"

        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(self.model_dir, exist_ok=True)

        self.search_data_path = self.model_dir + "search_data.pkl"
        self.obj_func_path = self.model_dir + "objective_function.pkl"

        self.n_old_samples = 0
        self.n_new_samples = 0

    def remove_model_data(self):
        """Delete the model directory and its content (best effort)."""
        try:
            shutil.rmtree(self.model_dir)
        except OSError:
            # Nothing to remove, or not removable: deliberately ignored.
            pass

    def _dill_dump(self, object_, path):
        """Serialise ``object_`` with dill into the file at ``path``."""
        # The directory may have been deleted by remove_model_data().
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(path, "wb") as handle:
            dill.dump(object_, handle)

    def _dill_load(self, path):
        """Return the object stored at ``path``; None if missing or empty."""
        if not self._pkl_valid(path):
            return None

        # NOTE(review): dill.load can execute arbitrary code, just like
        # pickle; only load files from a trusted storage directory.
        with open(path, "rb") as handle:
            return dill.load(handle)

    def _pkl_valid(self, pkl_path):
        """Return True if ``pkl_path`` is an existing, non-empty file."""
        return os.path.isfile(pkl_path) and os.path.getsize(pkl_path) > 0

    def _get_old_samples_size(self, df):
        """Set ``n_old_samples`` to the row count of ``df`` (0 if no frame)."""
        # Resetting to 0 keeps a count from an earlier load from lingering.
        self.n_old_samples = len(df) if isinstance(df, pd.DataFrame) else 0

    def load_obj_func(self):
        """Return the stored objective function, or None if there is none."""
        return self._dill_load(self.obj_func_path)

    def load_search_data(self):
        """Return the stored search data, or None if there is none."""
        return self._dill_load(self.search_data_path)

    def load(self):
        """Load the stored search data, print the sample count, return it.

        Returns None when no search data has been stored yet.
        """
        print("Reading in long term memory ...", end="\r")
        self.results_old = self._dill_load(self.search_data_path)
        self._get_old_samples_size(self.results_old)

        print(
            "Reading long term memory was successful:",
            self.n_old_samples,
            "samples found",
        )

        return self.results_old

    def save(self, dataframe, objective_function):
        """Merge ``dataframe`` into the stored search data and write it back.

        Rows already on disk are combined with the new rows, duplicates are
        dropped, and the result replaces ``search_data.pkl``. The objective
        function is written to ``objective_function.pkl``.

        Parameters
        ----------
        dataframe : pandas.DataFrame
            New search data; must have the same columns as the stored data.
        objective_function : object
            Any dill-serialisable object.
        """
        self.results_old = self._dill_load(self.search_data_path)

        if self.results_old is None:
            # No stored data: every row in ``dataframe`` counts as new.
            self.n_old_samples = 0
        else:
            self.n_old_samples = len(self.results_old)

            dataframe = merge_unhashable_df(dataframe, self.results_old)
            dataframe = drop_duplicates_unhashable_df(dataframe).reset_index(
                drop=True
            )

        self.n_new_samples = len(dataframe)

        self._dill_dump(objective_function, self.obj_func_path)
        print("Saving long term memory ...", end="\r")

        self._dill_dump(dataframe, self.search_data_path)

        print(
            "Saving long term memory was successful:",
            self.n_new_samples - self.n_old_samples,
            "new samples found",
        )
131
|
|
|
|