Passed
Push — master ( 916d9f...163de3 )
by Simon
02:25
created

LongTermMemory._get_pkl_hash()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
import glob
7
import json
8
import dill
9
import datetime
10
import hashlib
11
import inspect
12
13
import numpy as np
14
import pandas as pd
15
16
17
class Memory:
    """Base class holding the state shared by all memory implementations.

    The constructor only stores references and initialises bookkeeping
    fields; it performs no I/O.

    Parameters
    ----------
    _space_ : object
        Search-space object, stored as ``self._space_``.
    _main_args_ : object
        Argument container; its ``memory`` attribute is read and stored as
        ``self.memory_type``.
    _cand_ : object
        Accepted so that all memory classes share one constructor
        signature; the base class does not read it.
    """

    def __init__(self, _space_, _main_args_, _cand_):
        self._space_ = _space_
        self._main_args_ = _main_args_

        # Best position/score seen so far. Starting at -inf guarantees that
        # the first real score compares as an improvement.
        self.pos_best = None
        self.score_best = -np.inf

        self.memory_type = _main_args_.memory
        # Lookup table of already known scores, filled by subclasses.
        self.memory_dict = {}

        # Flipped to True by subclasses once stored meta data was read.
        self.meta_data_found = False
29
30
31
class ShortTermMemory(Memory):
    """Memory that lives only in the attributes set up by ``Memory``.

    The class adds no state or behaviour of its own.  It is constructed
    exactly like its base class, ``ShortTermMemory(_space_, _main_args_,
    _cand_)``; the inherited ``__init__`` is used directly instead of an
    override that would only forward the same three arguments.
    """
34
35
36
class LongTermMemory(Memory):
    """Memory that persists evaluated parameters and scores on disk.

    Directory layout created below the directory of this module::

        meta_data/
            <sha1 of the objective function source>/      (self.func_path)
                <feature hash>_<label hash>_.csv          collected results
                objective_function.py                     source of the function
                <sha1>.pkl                                dilled non-primitive values
                <run timestamp>/                          (self.date_path)
                    search_config.py
                    run_data.json

    Parameters
    ----------
    _space_ : object
        Search-space object; ``para_names``, ``search_space`` and
        ``pos2para`` are used by this class.
    _main_args_ : object
        Argument container; ``X`` and ``y`` are hashed on construction,
        further attributes are read in ``save_memory``.
    _cand_ : object
        Candidate object; its ``func_`` attribute (the objective function)
        determines the storage directory.
    """

    # Values of these types are written to the CSV as they are; everything
    # else is replaced by the hash of its dill dump.
    _PRIMITIVE_TYPES = (int, float, str)

    def __init__(self, _space_, _main_args_, _cand_):
        super().__init__(_space_, _main_args_, _cand_)

        self.score_col_name = "mean_test_score"

        self.feature_hash = self._get_hash(_main_args_.X)
        self.label_hash = self._get_hash(_main_args_.y)

        # dirname() instead of splitting on "/" so that the module directory
        # is also found when the path uses a different separator.
        meta_learn_path = os.path.dirname(os.path.realpath(__file__))

        self.datetime = datetime.datetime.now().strftime("%d.%m.%Y - %H:%M:%S")
        func_str = self._get_func_str(_cand_.func_)
        self.func_path_ = self._get_hash(func_str.encode("utf-8")) + "/"

        self.meta_path = meta_learn_path + "/meta_data/"
        self.func_path = self.meta_path + self.func_path_
        self.date_path = self.func_path + self.datetime + "/"

        # exist_ok already covers the "directory is present" case.
        os.makedirs(self.date_path, exist_ok=True)

    def load_memory(self, _cand_, _verb_):
        """Load stored results for the objective function into memory.

        Does nothing when no stored meta data is found.  Otherwise the
        stored ``eval_time`` column is copied to ``_cand_.eval_time`` and
        all stored scores are inserted into ``self.memory_dict``.
        """
        para, score = self._read_func_metadata(_cand_.func_, _verb_)
        if para is None or score is None:
            return

        _verb_.load_samples(para)

        _cand_.eval_time = list(para["eval_time"])

        self._load_data_into_memory(para, score)

    def save_memory(self, _main_args_, _opt_args_, _cand_):
        """Write the results of the current run to the meta-data directory.

        Writes/updates the result CSV, stores the objective function source
        and the search configuration (each only if the file does not exist
        yet) and dumps a ``run_data.json`` summary of the run.

        ``_main_args_`` and ``_opt_args_`` are part of the call signature
        but are not read here; the arguments stored at construction time
        (``self._main_args_``) are used instead.
        """
        path = self._get_file_path(_cand_.func_)
        meta_data = self._collect(_cand_)

        meta_data["run"] = self.datetime

        self._save_toCSV(meta_data, path)

        obj_func_path = self.func_path + "objective_function.py"
        if not os.path.exists(obj_func_path):
            # utf-8 matches the encoding used to hash the source in __init__.
            with open(obj_func_path, "w", encoding="utf-8") as file:
                file.write(self._get_func_str(_cand_.func_))

        search_config_path = self.date_path + "search_config.py"
        # Non-string keys are replaced by their __name__.  The mapping is
        # rebuilt instead of being mutated while it is iterated.
        search_config_temp = {
            (key if isinstance(key, str) else key.__name__): value
            for key, value in dict(self._main_args_.search_config).items()
        }

        if not os.path.exists(search_config_path):
            with open(search_config_path, "w", encoding="utf-8") as file:
                file.write("search_config = " + str(search_config_temp))

        run_data = {
            "random_state": self._main_args_.random_state,
            # This entry used to repeat random_state.  getattr keeps the
            # dump working should the args object carry no max_time.
            "max_time": getattr(self._main_args_, "max_time", None),
            "n_iter": self._main_args_.n_iter,
            "optimizer": self._main_args_.optimizer,
            "n_jobs": self._main_args_.n_jobs,
            # float(): NumPy integer scalars are not JSON serializable.
            "eval_time": float(np.array(_cand_.eval_time).sum()),
            "total_time": _cand_.total_time,
        }

        with open(self.date_path + "run_data.json", "w") as f:
            json.dump(run_data, f, indent=4)

    def _save_toCSV(self, meta_data_new, path):
        """Merge ``meta_data_new`` into the CSV at ``path`` and write it.

        If the file exists, old and new rows are concatenated and rows that
        agree in every column except the score/time/run columns are reduced
        to their first occurrence.
        """
        if os.path.exists(path):
            meta_data_old = pd.read_csv(path)
            # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
            meta_data = pd.concat([meta_data_old, meta_data_new])

            noScore = ("mean_test_score", "cv_default_score", "eval_time", "run")
            columns_noScore = [c for c in meta_data.columns if c not in noScore]

            meta_data = meta_data.drop_duplicates(subset=columns_noScore)
        else:
            meta_data = meta_data_new

        meta_data.to_csv(path, index=False)

    def _read_func_metadata(self, model_func, _verb_):
        """Read all stored result files for the current data set.

        Returns
        -------
        tuple
            ``(para, score)`` data frames, where ``score`` holds every
            column whose name contains ``self.score_col_name`` and ``para``
            the remaining columns; ``(None, None)`` if no file was found.
        """
        meta_data_list = [pd.read_csv(path) for path in self._get_func_data_names()]

        if not meta_data_list:
            _verb_.no_meta_data(model_func)
            return None, None

        self.meta_data_found = True
        meta_data = pd.concat(meta_data_list, ignore_index=True)

        score_name = [
            name for name in meta_data.columns if self.score_col_name in name
        ]

        para = meta_data.drop(score_name, axis=1)
        score = meta_data[score_name]

        _verb_.load_meta_data()
        return para, score

    def _get_opt_meta_data(self):
        """Convert ``self.memory_dict`` into parameter and score lists.

        Non-primitive parameter values are dilled to
        ``<func_path>/<sha1>.pkl`` and replaced by that hash.  Entries with
        a score of exactly 0 are left out of the result.

        Returns
        -------
        dict
            ``{"params": [...], "mean_test_score": [...]}``.
        """
        para_list = []
        score_list = []

        for pos_key, score in self.memory_dict.items():
            # frombuffer replaces the deprecated binary mode of
            # np.fromstring; copy() keeps the array writable as before.
            pos = np.frombuffer(pos_key, dtype=int).copy()
            para = self._space_.pos2para(pos)

            for para_name in list(para):
                value = para[para_name]
                if self._is_primitive(value):
                    continue

                para_dill = dill.dumps(value)
                para_hash = self._get_hash(para_dill)

                # The dill *bytes* are dumped once more; _read_dill undoes
                # both layers, so the on-disk format must stay like this.
                with open(
                    self.func_path + str(para_hash) + ".pkl", "wb"
                ) as pickle_file:
                    dill.dump(para_dill, pickle_file)

                para[para_name] = para_hash

            if score != 0:
                para_list.append(para)
                score_list.append(score)

        return {"params": para_list, "mean_test_score": score_list}

    def _load_data_into_memory(self, paras, scores):
        """Insert stored parameter/score rows into ``self.memory_dict``.

        Hashes in ``paras`` are replaced by the stored objects, parameters
        are translated into search-space positions, and ``score_best`` /
        ``pos_best`` are updated along the way.
        """
        import tqdm

        paras = paras.replace(self._hash2obj())
        paras = self.para2pos(paras)

        for idx in tqdm.tqdm(range(paras.shape[0])):
            pos = paras.iloc[[idx]].values
            # tobytes() is the replacement for tostring(), which NumPy 2.0
            # removed; both return the same raw bytes.
            pos_str = pos.tobytes()

            # item() extracts the single stored score without relying on
            # the deprecated float() conversion of a 1-element array.
            score = float(np.asarray(scores.values[idx]).item())
            self.memory_dict[pos_str] = score

            if score > self.score_best:
                self.score_best = score
                self.pos_best = pos

    def para2pos(self, paras):
        """Translate parameter values into their search-space indices.

        Returns a new data frame restricted to ``self._space_.para_names``
        in which every search-space column holds the index of the value
        inside ``self._space_.search_space[column]``.
        """
        # Explicit copy so the column assignments below never write into a
        # view of the caller's frame (avoids SettingWithCopy issues).
        paras = paras[self._space_.para_names].copy()
        for pos_key in self._space_.search_space:
            paras[pos_key] = paras[pos_key].apply(
                self._space_.search_space[pos_key].index
            )

        return paras

    def _collect(self, _cand_):
        """Build one data frame of parameters, scores and eval times."""
        results_dict = self._get_opt_meta_data()

        para_pd = pd.DataFrame(results_dict["params"])
        metric_pd = pd.DataFrame(
            results_dict["mean_test_score"], columns=["mean_test_score"]
        )

        eval_time = pd.DataFrame(_cand_.eval_time, columns=["eval_time"])
        md_model = pd.concat(
            [para_pd, metric_pd, eval_time], axis=1, ignore_index=False
        )

        return md_model

    def _get_hash(self, object):
        """Return the SHA-1 hex digest of a bytes-like ``object``."""
        # The parameter name shadows the builtin but is kept so keyword
        # callers are not broken.
        return hashlib.sha1(object).hexdigest()

    def _get_func_str(self, func):
        """Return the source code of ``func`` as a string."""
        return inspect.getsource(func)

    def _is_primitive(self, value):
        """Return True if ``value`` is stored as-is (int, float or str)."""
        return isinstance(value, self._PRIMITIVE_TYPES)

    def _data_file_path(self):
        """Return the path of the result CSV for the current data set."""
        return self.func_path + (
            self.feature_hash + "_" + self.label_hash + "_.csv"
        )

    def _get_subdirs(self):
        """Return all sub-directories of the function directory."""
        return glob.glob(self.func_path + "*/")

    def _get_func_data_names1(self):
        """Return the paths of all CSV files inside the sub-directories."""
        path_list = []
        for subdir in self._get_subdirs():
            path_list.extend(glob.glob(subdir + "*.csv"))

        return path_list

    def _get_func_data_names(self):
        """Return a list with the result CSV path, empty if it is missing."""
        return glob.glob(self._data_file_path())

    def _read_dill(self, value):
        """Load the object stored under the hash ``value``.

        Only the first matching ``.pkl`` file is read.  If there is none,
        ``value`` itself is returned unchanged.
        """
        paths = self._get_pkl_hash(value)
        if not paths:
            return value

        # SECURITY NOTE: dill executes arbitrary code while loading.  Only
        # files this class wrote itself should live in the meta-data folder.
        with open(paths[0], "rb") as fp:
            # Two layers: the file holds a dill dump of dill bytes.
            return dill.loads(dill.load(fp))

    def _hash2obj(self):
        """Map each hash of a non-primitive search-space value to its object."""
        return {
            para_hash: self._read_dill(para_hash)
            for para_hash in self._get_para_hash_list()
        }

    def _get_para_hash_list(self):
        """Return the hashes of all non-primitive search-space values."""
        para_hash_list = []
        for values in self._space_.search_space.values():
            for value in values:
                if self._is_primitive(value):
                    continue

                para_hash_list.append(self._get_hash(dill.dumps(value)))

        return para_hash_list

    def _get_pkl_hash(self, hash):
        """Return the ``.pkl`` paths whose file name starts with ``hash``."""
        return glob.glob(self.func_path + hash + "*.pkl")

    def _get_file_path(self, model_func):
        """Return the result CSV path, creating the run directory if needed.

        ``model_func`` is accepted for interface compatibility and not used.
        """
        # exist_ok avoids the race between an existence check and makedirs.
        os.makedirs(self.date_path, exist_ok=True)

        return self._data_file_path()
320