Completed
Push — master ( b4252e...a98e6f )
by Simon
01:25
created

hyperactive.memory.memory_helper.merge_model_IDs()   B

Complexity

Conditions 5

Size

Total Lines 27
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 19
nop 2
dl 0
loc 27
rs 8.9833
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
import json
7
import shutil
8
import hashlib
9
import inspect
10
11
12
# Directory containing this module. os.path.dirname is separator-aware;
# splitting on "/" fails to unpack on Windows, where realpath() returns
# backslash-separated paths.
current_path = os.path.realpath(__file__)
meta_learn_path = os.path.dirname(current_path)
# The trailing separator (the empty last component) is intentional: other
# code in this module builds paths as ``meta_path + name``.
meta_path = os.path.join(meta_learn_path, "meta_data", "")

# TODO: planned meta-learning API, not implemented yet:
#
# def get_best_models(X, y):
#     # TODO: model_dict   key:model   value:score
#     return model_dict
#
# def get_model_search_config(model):
#     # TODO
#     return search_config
#
# def get_model_init_config(model):
#     # TODO
#     return init_config
32
33
34
def delete_model(model):
    """Delete the memory directory stored for ``model``.

    The directory name is the hash of the model's source code, located
    directly under ``meta_path``. A status message is printed whether or
    not the directory existed.
    """
    model_dir = f"{meta_path}{_get_model_hash(model)}"

    # isdir() is already False for a missing path, so it doubles as the
    # existence check.
    if not os.path.isdir(model_dir):
        print("Model data not found in memory")
        return

    shutil.rmtree(model_dir)
    print("Model data successfully removed")
43
44
45
def delete_model_dataset(model, X, y):
    """Delete the CSV file stored for ``model`` on the dataset ``(X, y)``.

    The file location is produced by ``_get_file_path``. A status message
    is printed whether or not the file existed.
    """
    dataset_file = _get_file_path(model, X, y)

    if not os.path.exists(dataset_file):
        print("Model data not found in memory")
        return

    os.remove(dataset_file)
    print("Model data successfully removed")
53
54
55
def connect_model_IDs(model1, model2):
    """Record a two-way link between the memory IDs of two models.

    Each model's ID is the hash of its source code. Links are kept in
    ``model_connections.json`` under ``meta_path`` as a JSON object that
    maps an ID to the list of IDs linked to it, and both directions are
    stored. If the file does not exist yet, an empty mapping is used and
    the file (and its directory) are created when writing.

    Parameters
    ----------
    model1, model2
        Objects accepted by ``inspect.getsource`` (presumably the model
        functions used for the search).
    """
    # TODO: check that both models' search spaces have the same dimensions.
    connections_file = meta_path + "model_connections.json"

    try:
        with open(connections_file) as f:
            data = json.load(f)
    except FileNotFoundError:
        # First connection ever: start from an empty mapping instead of crashing.
        data = {}

    model1_hash = _get_model_hash(model1)
    model2_hash = _get_model_hash(model2)

    # Store the link in both directions so it can be found from either model.
    # A status message is printed for each direction.
    for key_model, value_model in (
        (model1_hash, model2_hash),
        (model2_hash, model1_hash),
    ):
        if key_model in data:
            data = _connect_key2value(data, key_model, value_model)
        else:
            data[key_model] = [value_model]
            print("IDs successfully connected")

    # The meta-data directory may not exist yet on a fresh install.
    os.makedirs(meta_path, exist_ok=True)
    with open(connections_file, "w") as f:
        json.dump(data, f, indent=4)
82
83
84
def _connect_key2value(data, key_model, value_model):
85
    if value_model in data[key_model]:
86
        print("IDs of models are already connected")
87
    else:
88
        data[key_model].append(value_model)
89
        print("IDs successfully connected")
90
91
    return data
92
93
94
def _split_key_value(data, key_model, value_model):
95
    if value_model in data[key_model]:
96
        data[key_model].remove(value_model)
97
98
        if len(data[key_model]) == 0:
99
            del data[key_model]
100
        print("ID connection successfully deleted")
101
    else:
102
        print("IDs of models are already connected")
103
104
    return data
105
106
107
def split_model_IDs(model1, model2):
    """Remove the two-way link between the memory IDs of two models.

    This undoes ``connect_model_IDs``. Each model's ID is the hash of its
    source code, and links are read from and written back to
    ``model_connections.json`` under ``meta_path``. If that file does not
    exist, no links can exist either: a message is printed and nothing is
    written.

    Parameters
    ----------
    model1, model2
        Objects accepted by ``inspect.getsource`` (presumably the model
        functions used for the search).
    """
    # TODO: check that both models' search spaces have the same dimensions.
    connections_file = meta_path + "model_connections.json"

    try:
        with open(connections_file) as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing has ever been connected, so there is nothing to split.
        print("IDs of models are not connected")
        return

    model1_hash = _get_model_hash(model1)
    model2_hash = _get_model_hash(model2)

    # Links are stored in both directions; remove each one independently.
    for key_model, value_model in (
        (model1_hash, model2_hash),
        (model2_hash, model1_hash),
    ):
        if key_model in data:
            data = _split_key_value(data, key_model, value_model)
        else:
            print("IDs of models are not connected")

    with open(connections_file, "w") as f:
        json.dump(data, f, indent=4)
132
133
134
def _get_file_path(model, X, y):
    """Return the CSV path used to store data for ``model`` on ``(X, y)``.

    Layout: ``<meta_path>/<model hash>/<X hash>_<y hash>_.csv``, where the
    model hash is derived from its source code. ``X`` and ``y`` are passed
    directly to SHA-1, so they must be bytes-like objects (presumably
    contiguous numpy arrays in practice; TODO confirm for all callers).
    """
    model_dir = os.path.join(meta_path, _get_model_hash(model))
    file_name = _get_hash(X) + "_" + _get_hash(y) + "_.csv"
    # os.path.join uses the platform separator instead of a hard-coded "/".
    return os.path.join(model_dir, file_name)
142
143
144
def _get_model_hash(model):
    """Return the SHA-1 hex digest of ``model``'s UTF-8 encoded source code."""
    source_bytes = _get_func_str(model).encode("utf-8")
    # hexdigest() already yields a str, so no extra conversion is needed.
    return _get_hash(source_bytes)
146
147
148
def _get_func_str(func):
149
    return inspect.getsource(func)
150
151
152
def _get_hash(object):
153
    return hashlib.sha1(object).hexdigest()
154