Passed
Push — master ( d5b48a...5a31d8 )
by Simon
01:07
created

optimization_metadata.memory_dump   A

Complexity

Total Complexity 26

Size/Duplication

Total Lines 137
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 26
eloc 91
dl 0
loc 137
rs 10
c 0
b 0
f 0

10 Methods

Rating   Name   Duplication   Size   Complexity  
A MemoryDump.__init__() 0 2 1
A MemoryDump.memory_dict2dataframe() 0 23 4
A MemoryDump._search_space_types() 0 17 5
A MemoryDump.dump_dataframe() 0 19 2
A MemoryDump.dump_object() 0 3 2
A MemoryDump._get_file_path() 0 5 2
A MemoryDump.save_search_data() 0 10 1
A MemoryDump.dump_dict() 0 3 2
A MemoryDump.save_dataset_info() 0 8 3
A MemoryDump._create_hash_list() 0 17 4
# Author: Simon Blanke
# Email: [email protected]
# License: MIT License

import os
import json
import dill
import inspect
import random

import numpy as np
import pandas as pd

from operator import itemgetter
from .memory_io import MemoryIO
from .dataset_features import get_dataset_features
from .utils import object_hash
class MemoryDump(MemoryIO):
    """Persist optimization search data to disk.

    Converts an in-memory search record (``memory_dict``) into a pandas
    DataFrame and writes it — together with dataset features and dill-pickled
    parameter objects — into the directory layout managed by ``MemoryIO``.
    """

    def __init__(self, X, y, model, search_space):
        super().__init__(X, y, model, search_space)

    def dump_object(self, _object, path, name):
        """Serialize ``_object`` with dill to the file ``path + name``.

        dill is used instead of pickle because search-space entries may be
        functions/classes that pickle cannot serialize.
        """
        with open(path + name, "wb") as dill_file:
            dill.dump(_object, dill_file)

    def dump_dict(self, _dict, path):
        """Write ``_dict`` as pretty-printed JSON to ``path``."""
        with open(path, "w") as json_file:
            json.dump(_dict, json_file, indent=4)

    def dump_dataframe(self, _dataframe, path, name):
        """Merge ``_dataframe`` into the CSV at ``path + name``.

        If the CSV already exists, the new rows are appended and duplicates
        are dropped, where two rows count as duplicates when every column
        EXCEPT the score/timing columns matches (so re-evaluated parameter
        sets are not stored twice). Otherwise the DataFrame is written fresh.

        Raises:
            AssertionError: if the existing CSV has a different number of
                columns than ``_dataframe`` (incompatible meta data).
        """
        if os.path.exists(path + name):
            _dataframe_old = pd.read_csv(path + name)

            assert len(_dataframe_old.columns) == len(
                _dataframe.columns
            ), "Warning meta data dimensionality does not match"

            # FIX: DataFrame.append was deprecated in pandas 1.4 and removed
            # in 2.0; pd.concat with default arguments is the drop-in
            # replacement (same row order, indexes preserved).
            _dataframe_final = pd.concat([_dataframe_old, _dataframe])

            columns = list(_dataframe_final.columns)
            noScore = ["_score_", "cv_default_score", "eval_time", "run"]
            columns_noScore = [c for c in columns if c not in noScore]

            # drop_duplicates keeps the first occurrence, i.e. the row that
            # was already stored in the CSV wins over a re-evaluation.
            _dataframe_final = _dataframe_final.drop_duplicates(subset=columns_noScore)
        else:
            _dataframe_final = _dataframe

        _dataframe_final.to_csv(path + name, index=False)

    def memory_dict2dataframe(self, memory_dict, object2hash=True):
        """Convert ``memory_dict`` into a single DataFrame.

        ``memory_dict`` maps position tuples (one index per search-space
        dimension) to result dicts. Each position index is translated back
        into the corresponding search-space value; when the dimension holds
        objects and ``object2hash`` is True, the value is replaced by its
        precomputed hash from ``self.object_hash_dict`` (see
        ``_create_hash_list``).

        Returns:
            pd.DataFrame: parameter columns followed by the result columns.
        """
        tuple_list = list(memory_dict.keys())
        result_list = list(memory_dict.values())

        results_df = pd.DataFrame(result_list)
        # Rows = evaluations, columns = search-space dimensions.
        np_pos = np.array(tuple_list)

        para_dict = {}
        for i, key in zip(range(np_pos.shape[1]), self.search_space):
            np_pos_ = list(np_pos[:, i])
            search_space_list = list(self.search_space[key])

            if self.search_space_types[key] == "object" and object2hash:
                search_space_list = self.object_hash_dict[key]

            # Fancy-index the dimension's values by the stored positions.
            search_space_list = np.array(search_space_list)
            para_list = search_space_list[np_pos_]

            para_dict[key] = para_list

        para_df = pd.DataFrame(para_dict)
        return pd.concat([para_df, results_df], axis=1)

    def save_dataset_info(self, path, name):
        """Extract features of (X, y) and write them as JSON to path + name."""
        data_features = get_dataset_features(self.X, self.y)

        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)

        with open(path + name, "w") as f:
            json.dump(data_features, f, indent=4)

    def save_search_data(self, memory_dict, path, name):
        """Save the collected search data as a CSV at ``path + name``."""
        self._search_space_types()
        self._create_hash_list(path)
        meta_data_df = self.memory_dict2dataframe(memory_dict)

        # meta_data_df["run"] = self.datetime

        self.dump_dataframe(meta_data_df, path, name)

        # NOTE(review): reports self.path, but the file was written to the
        # `path` argument — confirm these are the same location.
        print("\nMeta data saved in:\n", self.path)

    def _get_file_path(self, model_func):
        """Ensure the date directory exists and return the meta-data path.

        NOTE(review): ``model_func`` is currently unused — kept for
        interface compatibility with callers.
        """
        if not os.path.exists(self.date_path):
            os.makedirs(self.date_path)

        return self.model_path + self.meta_data_name

    def _search_space_types(self):
        """Classify each search-space dimension as int/float/str/object.

        The type is inferred from the FIRST element of each dimension only;
        mixed-type dimensions are therefore classified by their first value.
        Result is stored in ``self.search_space_types``.
        """
        self.search_space_types = {}
        for key in self.search_space.keys():
            search_space_list = list(self.search_space[key])

            value = search_space_list[0]

            if isinstance(value, int):
                self.search_space_types[key] = "int"
            elif isinstance(value, float):
                self.search_space_types[key] = "float"
            elif isinstance(value, str):
                self.search_space_types[key] = "str"
            else:
                self.search_space_types[key] = "object"

    def _create_hash_list(self, path):
        """Dump object-valued search-space entries and record their hashes.

        For every dimension typed "object" (see ``_search_space_types``),
        each value is dill-serialized, written to ``path + <hash>.pkl`` and
        its hash appended to ``self.object_hash_dict[key]`` so that
        ``memory_dict2dataframe`` can store hashes instead of objects.
        """
        self.object_hash_dict = {}

        for key in self.search_space.keys():
            if self.search_space_types[key] == "object":
                search_space_list = list(self.search_space[key])

                object_hash_list = []

                for value in search_space_list:
                    para_dill = dill.dumps(value)
                    para_hash = object_hash(para_dill)

                    self.dump_object(para_dill, path, str(para_hash) + ".pkl")
                    object_hash_list.append(para_hash)

                self.object_hash_dict[key] = object_hash_list