Passed
Push — master ( 6b2285...b7fe1a )
by Simon
02:03 queued 11s
created

LongTermMemory._get_file_path()   A

Complexity

Conditions 2

Size

Total Lines 18
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 15
dl 0
loc 18
rs 9.65
c 0
b 0
f 0
cc 2
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
import glob
7
import hashlib
8
import inspect
9
10
import numpy as np
11
import pandas as pd
12
13
14
class Memory:
    """Base state shared by the memory implementations.

    Stores the search space and the parsed main arguments, tracks the best
    position/score seen so far, and owns ``memory_dict``, the mapping from
    encoded positions to scores.
    """

    def __init__(self, _space_, _main_args_):
        # Configuration handed in by the caller.
        self._space_, self._main_args_ = _space_, _main_args_

        # Best result so far; -inf means nothing has been evaluated yet.
        self.pos_best, self.score_best = None, -np.inf

        # The memory mode is read from the main arguments; entries start empty.
        self.memory_type = _main_args_.memory
        self.memory_dict = dict()

        self.meta_data_found = False
26
27
28
class ShortTermMemory(Memory):
    """Memory variant that adds nothing on top of the ``Memory`` base class.

    It defines no load/save methods of its own; all state comes from
    ``Memory.__init__``.
    """

    def __init__(self, _space_, _main_args_):
        # Delegate all initialization to the base class.
        super(ShortTermMemory, self).__init__(_space_, _main_args_)
31
32
33
class LongTermMemory(Memory):
    """Memory that persists evaluated positions and scores as CSV meta data.

    Files live under ``<this module's directory>/meta_data/<sha1 of the
    objective function's source>/``. Loading reads every
    ``metadata*__.csv`` file in that directory. Saving writes one file per
    (features, labels) pair, named by SHA-1 hashes of ``X`` and ``y``.
    """

    def __init__(self, _space_, _main_args_):
        super().__init__(_space_, _main_args_)

        self.score_col_name = "mean_test_score"

        # Use os.path rather than splitting on "/" so this also works on
        # Windows. The trailing separator is kept because the path is used as
        # a plain string prefix below.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        self.meta_data_path = os.path.join(module_dir, "meta_data", "")

    def load_memory(self, model_func):
        """Load all stored meta data for ``model_func`` into ``memory_dict``.

        Does nothing if no meta data file exists for the function.
        """
        para, score = self._read_func_metadata(model_func)
        if para is None or score is None:
            return

        self._load_data_into_memory(para, score)

    def save_memory(self, _main_args_, _cand_):
        """Write the current memory to the CSV file for ``_cand_.func_``.

        ``_main_args_`` is not used here. The file name is derived from the
        ``self._main_args_`` stored at construction. The parameter is kept
        for interface compatibility.
        """
        meta_data = self._collect()
        path = self._get_file_path(_cand_.func_)
        self._save_toCSV(meta_data, path)

    def _save_toCSV(self, meta_data_new, path):
        """Merge ``meta_data_new`` into the CSV at ``path`` and write it back.

        Rows are de-duplicated on all non-score columns. When a parameter
        combination is already in the file, the existing (old) row is kept.
        """
        if os.path.exists(path):
            meta_data_old = pd.read_csv(path)
            # DataFrame.append was removed in pandas 2.0. concat with the
            # default arguments is the equivalent operation.
            meta_data = pd.concat([meta_data_old, meta_data_new])

            score_cols = {self.score_col_name, "cv_default_score"}
            para_cols = [c for c in meta_data.columns if c not in score_cols]

            meta_data = meta_data.drop_duplicates(subset=para_cols)
        else:
            meta_data = meta_data_new

        meta_data.to_csv(path, index=False)

    def _read_func_metadata(self, model_func):
        """Read and concatenate every stored meta data file for ``model_func``.

        Returns:
            ``(para, score)`` DataFrames. ``score`` holds every column whose
            name contains ``self.score_col_name``, and ``para`` holds the
            remaining columns. Returns ``(None, None)`` when no file is found.
        """
        paths = glob.glob(self._get_func_file_paths(model_func))

        meta_data_list = [pd.read_csv(path) for path in paths]
        if not meta_data_list:
            print("Warning: No meta data found for following function:", model_func)
            return None, None

        self.meta_data_found = True
        meta_data = pd.concat(meta_data_list, ignore_index=True)

        score_name = [
            name for name in meta_data.columns if self.score_col_name in name
        ]
        para = meta_data.drop(score_name, axis=1)
        score = meta_data[score_name]

        print("Loading meta data successful")
        return para, score

    def _get_opt_meta_data(self):
        """Decode ``memory_dict`` into parameter sets and their scores.

        Entries whose score is exactly 0 are skipped.

        Returns:
            dict with keys ``"params"`` (list of ``pos2para`` results) and
            ``"mean_test_score"`` (list of scores), in matching order.
        """
        para_list = []
        score_list = []

        for key, score in self.memory_dict.items():
            # Keys are raw bytes of an int position array. np.fromstring on
            # binary data is deprecated; frombuffer is the replacement. copy()
            # keeps the result writable, as fromstring's result was.
            pos = np.frombuffer(key, dtype=int).copy()
            para = self._space_.pos2para(pos)

            if score != 0:
                para_list.append(para)
                score_list.append(score)

        return {"params": para_list, "mean_test_score": score_list}

    def _load_data_into_memory(self, paras, scores):
        """Insert each (parameter row, score) pair into ``memory_dict``.

        Also updates ``score_best`` / ``pos_best`` when a loaded score beats
        the current best.
        """
        for idx in range(paras.shape[0]):
            pos = self._space_.para2pos(paras.iloc[[idx]])
            # tobytes() is the non-deprecated spelling of tostring().
            pos_str = pos.tobytes()

            # When scores is a DataFrame, scores.values[idx] is a length-1
            # row. .item() extracts the scalar (float() on a 1-element array
            # is deprecated since NumPy 1.25) and still raises if the row has
            # more than one value.
            score = float(np.asarray(scores.values[idx]).item())
            self.memory_dict[pos_str] = score

            if score > self.score_best:
                self.score_best = score
                self.pos_best = pos

    def _get_para(self):
        """Return the stored parameter sets as a DataFrame."""
        results_dict = self._get_opt_meta_data()
        return pd.DataFrame(results_dict["params"])

    def _get_score(self):
        """Return the stored scores as a one-column DataFrame."""
        results_dict = self._get_opt_meta_data()
        return pd.DataFrame(
            results_dict["mean_test_score"], columns=["mean_test_score"]
        )

    def _collect(self):
        """Return the memory as one DataFrame: parameter columns + score column."""
        para_pd = self._get_para()
        metric_pd = self._get_score()

        return pd.concat([para_pd, metric_pd], axis=1, ignore_index=False)

    def _get_hash(self, object):
        """Return the SHA-1 hex digest of ``object``.

        ``object`` must be bytes-like (e.g. ``bytes`` or a C-contiguous numpy
        array). The parameter name shadows the builtin but is kept for
        interface compatibility.
        """
        return hashlib.sha1(object).hexdigest()

    def _get_func_str(self, func):
        """Return the source code of ``func``."""
        return inspect.getsource(func)

    def _get_func_dir(self, model_func):
        """Return the meta data directory for ``model_func``, creating it if needed.

        Side effect: sets ``self.func_path`` to ``"<sha1 of source>/"``.
        """
        func_str = self._get_func_str(model_func)
        self.func_path = self._get_hash(func_str.encode("utf-8")) + "/"

        directory = self.meta_data_path + self.func_path
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(directory, exist_ok=True)
        return directory

    def _get_func_file_paths(self, model_func):
        """Return a glob pattern matching every meta data file for ``model_func``."""
        return self._get_func_dir(model_func) + "metadata" + "*" + "__.csv"

    def _get_file_path(self, model_func):
        """Return the CSV path for ``model_func`` and the current ``X``/``y`` data."""
        feature_hash = self._get_hash(self._main_args_.X)
        label_hash = self._get_hash(self._main_args_.y)

        directory = self._get_func_dir(model_func)
        return directory + (
            "metadata"
            + "__feature_hash="
            + feature_hash
            + "__label_hash="
            + label_hash
            + "__.csv"
        )
181