Passed
Push — master ( b8f359...927840 )
by Simon
03:10
created

hyperactive.long_term_memory.long_term_memory   A

Complexity

Total Complexity 24

Size/Duplication

Total Lines 130
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 84
dl 0
loc 130
rs 10
c 0
b 0
f 0
wmc 24

10 Methods

Rating   Name   Duplication   Size   Complexity  
A LongTermMemory.load_search_data() 0 2 1
A LongTermMemory.__init__() 0 23 4
A LongTermMemory.save() 0 20 2
A LongTermMemory.load() 0 12 1
A LongTermMemory._get_old_samples_size() 0 3 2
A LongTermMemory.load_obj_func() 0 2 1
A LongTermMemory._pkl_valid() 0 2 1
A LongTermMemory.remove_model_data() 0 5 2
A LongTermMemory._dill_load() 0 6 3
A LongTermMemory._dill_dump() 0 3 2

3 Functions

Rating   Name   Duplication   Size   Complexity  
A drop_duplicates_unhashable_df() 0 2 1
A merge_unhashable_df() 0 18 3
A meta_data_path() 0 3 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
import dill
7
import shutil
8
import pandas as pd
9
10
from .ltm_data_path import ltm_data_path
11
12
13
def merge_unhashable_df(df1, df2):
    """Append the rows of ``df2`` below the rows of ``df1``.

    A disjoint integer "item" key is assigned to every row and an outer
    merge on that key stacks the two frames.  This works even when cells
    hold unhashable values (lists, dicts, ...), which rules out
    hash-based approaches.

    Parameters
    ----------
    df1, df2 : pd.DataFrame
        Must have the same set of columns.  The inputs are not modified.

    Returns
    -------
    pd.DataFrame or None
        The merged dataframe, or None if the column sets differ.
    """
    columns = df1.columns
    columns0 = df2.columns

    if set(columns) != set(columns0):
        print("Error columns of df1 and df2 must be the same")
        return

    # Work on copies so the caller's dataframes do not grow an "item"
    # column as a side effect (the previous implementation mutated its
    # arguments, permanently altering e.g. the stored results in save()).
    df1 = df1.copy()
    df2 = df2.copy()

    df1["item"] = range(len(df1))
    df2["item"] = range(len(df1), len(df1) + len(df2))

    result = pd.merge(df1, df2, on="item", how="outer", suffixes=["", "_y"])

    # The outer merge leaves NaN in df1's columns for df2's rows; fill
    # them from the duplicated "_y" columns, then discard the helpers.
    for col in columns:
        result[col].update(result[col + "_y"])

    result = result[result.columns[~result.columns.str.endswith("_y")]]
    return result.drop(["item"], axis=1)
31
32
33
def drop_duplicates_unhashable_df(df):
    """Drop duplicate rows from *df*, keeping each first occurrence.

    Duplicates are detected on the string representation of every row,
    so the dataframe may contain unhashable cells (lists, dicts, ...).
    """
    duplicate_mask = df.astype(str).duplicated()
    return df.loc[~duplicate_mask]
35
36
37
def meta_data_path():
    """Return the directory containing this module, with a trailing separator.

    Uses ``os.path`` helpers instead of splitting on "/" so the result is
    also correct on Windows, where ``os.path.realpath`` yields backslashes.
    """
    current_path = os.path.realpath(__file__)
    # os.path.join(dir, "") appends the platform's path separator.
    return os.path.join(os.path.dirname(current_path), "")
40
41
42
class LongTermMemory:
    """File-based long term memory for search results.

    Persists the search data (a dataframe) and the objective function of
    a model/study combination under ``<ltm_data_dir>/<model>:<study>/``
    as dill pickles, and merges new results with previously stored ones.
    """

    def __init__(self, model_name, study_name=None, path=None, verbosity=None):
        """Set up (and create, if necessary) the storage directory.

        Parameters
        ----------
        model_name : str
            Name of the model whose results are stored.
        study_name : str, optional
            Study identifier; defaults to "default".
        path : str, optional
            Base directory for the memory; defaults to the package's own
            ltm data path.
        verbosity : optional
            Accepted for interface compatibility; currently unused.
        """
        if study_name is None:
            study_name = "default"

        model_study_name = model_name + ":" + study_name

        if path is None:
            self.ltm_data_dir = ltm_data_path()
        else:
            self.ltm_data_dir = path + "/ltm_data/"

        self.model_dir = self.ltm_data_dir + model_study_name + "/"

        # exist_ok avoids the check-then-create race of the previous
        # os.path.exists() + os.makedirs() sequence.
        os.makedirs(self.model_dir, exist_ok=True)

        self.search_data_path = self.model_dir + "search_data.pkl"
        self.obj_func_path = self.model_dir + "objective_function.pkl"

        # Sample counters; updated by load()/save().
        self.n_old_samples = 0
        self.n_new_samples = 0

    def remove_model_data(self):
        """Delete this model's storage directory; ignore if already gone."""
        try:
            shutil.rmtree(self.model_dir)
        except OSError:
            pass

    def _dill_dump(self, object_, path):
        """Serialize *object_* to *path* with dill."""
        with open(path, "wb") as handle:
            dill.dump(object_, handle)

    def _dill_load(self, path):
        """Deserialize and return the object stored at *path*.

        Returns None when the file is missing or empty.
        """
        if self._pkl_valid(path):
            with open(path, "rb") as handle:
                return dill.load(handle)
        return None

    def _pkl_valid(self, pkl_path):
        """Return True if *pkl_path* is an existing, non-empty file."""
        return os.path.isfile(pkl_path) and os.path.getsize(pkl_path) > 0

    def _get_old_samples_size(self, df):
        """Record the number of previously stored samples, if any."""
        if isinstance(df, pd.DataFrame):
            self.n_old_samples = len(df)

    def load_obj_func(self):
        """Return the stored objective function, or None if not stored."""
        return self._dill_load(self.obj_func_path)

    def load_search_data(self):
        """Return the stored search data, or None if not stored."""
        return self._dill_load(self.search_data_path)

    def load(self):
        """Load the stored search data, report its size, and return it."""
        print("Reading in long term memory ...", end="\r")
        self.results_old = self._dill_load(self.search_data_path)
        self._get_old_samples_size(self.results_old)

        print(
            "Reading long term memory was successful:",
            self.n_old_samples,
            "samples found",
        )

        return self.results_old

    def save(self, dataframe, objective_function):
        """Merge *dataframe* with stored results and persist everything.

        Parameters
        ----------
        dataframe : pd.DataFrame
            Newly collected search data.
        objective_function : callable
            The objective function stored alongside the data.
        """
        self.results_old = self._dill_load(self.search_data_path)

        if self.results_old is not None:
            self.n_old_samples = len(self.results_old)

            # Combine with previously stored samples and drop exact
            # duplicates (string-based, so unhashable cells are fine).
            dataframe = merge_unhashable_df(dataframe, self.results_old)
            dataframe = drop_duplicates_unhashable_df(dataframe).reset_index(drop=True)

        self.n_new_samples = len(dataframe)

        self._dill_dump(objective_function, self.obj_func_path)
        print("Saving long term memory ...", end="\r")

        self._dill_dump(dataframe, self.search_data_path)

        print(
            "Saving long term memory was successful:",
            self.n_new_samples - self.n_old_samples,
            "new samples found",
        )
131