Passed
Push — master ( 9ff666...1a4396 )
by Simon
03:24
created

ShortTermMemory.__init__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
import glob
7
import hashlib
8
import inspect
9
10
import numpy as np
11
import pandas as pd
12
13
14
class Memory:
    """Base container for positions and scores that were already evaluated.

    Stores the search space and the main arguments, tracks the best
    position/score seen so far and keeps a dictionary of known scores.
    ``load_memory`` and ``save_memory`` are no-op hooks that subclasses
    may override.
    """

    def __init__(self, _space_, _main_args_):
        """Initialise an empty memory.

        Args:
            _space_: Search-space object; stored as-is.
            _main_args_: Argument object; its ``memory`` attribute is read
                and stored as ``memory_type``.
        """
        self._space_ = _space_
        self._main_args_ = _main_args_

        # Start at -inf so that any real score compares as better.
        self.pos_best = None
        self.score_best = -np.inf

        self.memory_type = _main_args_.memory
        # Filled by subclasses: position key -> score.
        self.memory_dict = {}

        self.meta_data_found = False

    def load_memory(self, model_func):
        """Do nothing; hook for subclasses that load stored scores."""

    def save_memory(self, _main_args_, _cand_):
        """Do nothing; hook for subclasses that persist scores."""
32
33
34
class ShortTermMemory(Memory):
    """Memory variant that adds nothing on top of ``Memory``.

    The constructor is inherited unchanged (``_space_``, ``_main_args_``);
    an override that only forwards to the parent is not needed. Loading and
    saving use the inherited methods.
    """
37
38
39
class LongTermMemory(Memory):
    """Memory that persists evaluated parameters and scores as CSV files.

    Files are stored below ``<directory of this module>/meta_data/`` in a
    sub-directory named after the SHA-1 hash of the model function's source
    code. Each file name contains SHA-1 hashes of the features and labels.
    """

    def __init__(self, _space_, _main_args_):
        """Set up the score column name and the meta-data directory path.

        Args:
            _space_: Search-space object providing ``pos2para``/``para2pos``.
            _main_args_: Argument object; ``memory``, ``X`` and ``y`` are read.
        """
        super().__init__(_space_, _main_args_)

        self.score_col_name = "mean_test_score"

        # os.path copes with the platform separator; the trailing "" keeps
        # the trailing separator that later string concatenation relies on.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        self.meta_data_path = os.path.join(module_dir, "meta_data", "")

    def load_memory(self, model_func):
        """Load stored scores for ``model_func`` into ``memory_dict``.

        Returns without changes when no meta data is found.
        """
        para, score = self._read_func_metadata(model_func)
        if para is None or score is None:
            return

        self._load_data_into_memory(para, score)

    def save_memory(self, _main_args_, _cand_):
        """Write the current memory to the CSV file for ``_cand_.func_``."""
        meta_data = self._collect()
        path = self._get_file_path(_cand_.func_)
        self._save_toCSV(meta_data, path)

    def _save_toCSV(self, meta_data_new, path):
        """Merge ``meta_data_new`` into the CSV at ``path`` (or create it).

        Rows that agree on every non-score column are de-duplicated, keeping
        the first occurrence (i.e. the row already on disk).
        """
        if os.path.exists(path):
            meta_data_old = pd.read_csv(path)
            # DataFrame.append was removed in pandas 2.0; concat is the
            # equivalent replacement.
            meta_data = pd.concat([meta_data_old, meta_data_new])

            score_cols = ("mean_test_score", "cv_default_score")
            para_cols = [c for c in meta_data.columns if c not in score_cols]

            meta_data = meta_data.drop_duplicates(subset=para_cols)
        else:
            meta_data = meta_data_new

        meta_data.to_csv(path, index=False)

    def _read_func_metadata(self, model_func):
        """Read every stored CSV for ``model_func``.

        Returns:
            ``(para, score)`` data frames, where ``score`` holds the columns
            whose name contains ``self.score_col_name`` and ``para`` holds
            the rest; ``(None, None)`` when no file exists.
        """
        paths = glob.glob(self._get_func_file_paths(model_func))
        meta_data_list = [pd.read_csv(path) for path in paths]

        if not meta_data_list:
            print("Warning: No meta data found for following function:", model_func)
            return None, None

        self.meta_data_found = True
        meta_data = pd.concat(meta_data_list, ignore_index=True)

        score_name = [
            name for name in meta_data.columns if self.score_col_name in name
        ]

        para = meta_data.drop(score_name, axis=1)
        score = meta_data[score_name]

        return para, score

    def _get_opt_meta_data(self):
        """Convert ``memory_dict`` into lists of parameters and scores.

        Entries whose score equals 0 are left out.

        Returns:
            Dict with the keys ``"params"`` and ``"mean_test_score"``.
        """
        para_list = []
        score_list = []

        for key, score in self.memory_dict.items():
            if score == 0:
                continue

            # frombuffer replaces the deprecated binary mode of fromstring;
            # copy() keeps the array writable like fromstring's result was.
            # NOTE(review): dtype=float assumes the keys were produced from
            # float arrays - confirm against the code that fills memory_dict.
            pos = np.frombuffer(key, dtype=float).copy()
            para_list.append(self._space_.pos2para(pos))
            score_list.append(score)

        return {"params": para_list, "mean_test_score": score_list}

    def _load_data_into_memory(self, paras, scores):
        """Fill ``memory_dict`` from the data frames and track the best score.

        Args:
            paras: Data frame with one parameter set per row.
            scores: Scores aligned row-wise with ``paras``.
        """
        for idx in range(paras.shape[0]):
            pos = self._space_.para2pos(paras.iloc[[idx]])
            # tobytes() replaces the deprecated tostring(); same bytes.
            pos_str = pos.tobytes()

            # item() extracts the single value without relying on the
            # deprecated float() conversion of a one-element array.
            score = float(np.asarray(scores.values[idx]).item())
            self.memory_dict[pos_str] = score

            if score > self.score_best:
                self.score_best = score
                self.pos_best = pos

    def _get_para(self):
        """Return the stored parameter sets as a data frame."""
        results_dict = self._get_opt_meta_data()

        return pd.DataFrame(results_dict["params"])

    def _get_score(self):
        """Return the stored scores as a one-column data frame."""
        results_dict = self._get_opt_meta_data()
        return pd.DataFrame(
            results_dict["mean_test_score"], columns=["mean_test_score"]
        )

    def _collect(self):
        """Return parameters and scores side by side in one data frame."""
        # Decode the memory once instead of once per frame.
        results_dict = self._get_opt_meta_data()

        para_pd = pd.DataFrame(results_dict["params"])
        metric_pd = pd.DataFrame(
            results_dict["mean_test_score"], columns=["mean_test_score"]
        )

        return pd.concat([para_pd, metric_pd], axis=1, ignore_index=False)

    def _get_hash(self, object):
        """Return the SHA-1 hex digest of a bytes-like ``object``."""
        return hashlib.sha1(object).hexdigest()

    def _get_func_str(self, func):
        """Return the source code of ``func``."""
        return inspect.getsource(func)

    def _get_func_dir(self, model_func):
        """Return (and create if missing) the directory for ``model_func``.

        Also stores the hashed sub-directory name in ``self.func_path``.
        """
        func_str = self._get_func_str(model_func)
        self.func_path = self._get_hash(func_str.encode("utf-8")) + "/"

        directory = self.meta_data_path + self.func_path
        # exist_ok avoids failing when the directory appears concurrently.
        os.makedirs(directory, exist_ok=True)

        return directory

    def _get_func_file_paths(self, model_func):
        """Return the glob pattern matching all CSV files of ``model_func``."""
        directory = self._get_func_dir(model_func)

        return directory + ("metadata" + "*" + "__.csv")

    def _get_file_path(self, model_func):
        """Return the CSV path for ``model_func`` and the current data set."""
        # hashlib needs a C-contiguous buffer; for arrays that already are
        # contiguous this is a no-op and the digest is unchanged.
        feature_hash = self._get_hash(np.ascontiguousarray(self._main_args_.X))
        label_hash = self._get_hash(np.ascontiguousarray(self._main_args_.y))

        directory = self._get_func_dir(model_func)

        return directory + (
            "metadata"
            + "__feature_hash="
            + feature_hash
            + "__label_hash="
            + label_hash
            + "__.csv"
        )
186