Passed
Push — master ( b3cc5b...cdf567 )
by Simon
01:59
created

SearchSpace._read_dill()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 2
nop 3
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import dill
6
import random
7
import numpy as np
8
9
10
class SearchSpace:
    """Discrete search space of a single model.

    Every hyperparameter maps to a list of candidate values. A "position" is
    a vector holding one integer index into each of those lists, ordered like
    the search-space keys.
    """

    def __init__(self, _core_, model_nr):
        """Pick the search space (and optional init config) of one model.

        Args:
            _core_: object exposing ``search_config`` and ``init_config``
                mappings; the ``model_nr``-th key of each selects this model.
                ``init_config`` may be falsy, meaning no init config.
            model_nr: index of the model within those mappings.
        """
        self.search_space = _core_.search_config[list(_core_.search_config)[model_nr]]
        self.pos_space_limit()
        self.init_type = None

        if _core_.init_config:
            self.init_para = _core_.init_config[list(_core_.init_config)[model_nr]]

            init_keys = list(self.init_para.keys())
            # An empty init config leaves init_type as None instead of raising IndexError.
            if init_keys:
                first_init_key = init_keys[0]
                # Init config starting with the search space's first parameter
                # is treated as a warm start.
                if first_init_key == next(iter(self.search_space.keys()), None):
                    self.init_type = "warm_start"
                elif first_init_key == "scatter_init":
                    self.init_type = "scatter_init"

    def pos_space_limit(self):
        """Store in ``self.dim`` the largest valid index of each hyperparameter."""
        # dtype=int keeps dim integral even when the search space is empty.
        self.dim = np.array(
            [len(values) - 1 for values in self.search_space.values()], dtype=int
        )

    def get_random_pos(self):
        """Return a uniformly random position (one index per hyperparameter).

        ``np.random.randint`` excludes ``high``, so ``self.dim + 1`` makes every
        index in ``0..dim`` equally likely. Rounding a continuous uniform draw
        instead would give the two end indices only half the probability of
        the interior ones.
        """
        return np.random.randint(0, self.dim + 1, size=self.dim.shape)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a uniformly random index into the values of ``hyperpara_name``."""
        n_para_values = len(self.search_space[hyperpara_name])
        # random.randint is inclusive on both ends.
        return random.randint(0, n_para_values - 1)

    def _read_dill(self, value, path):
        """Load and return the object stored at ``path``.

        ``value`` is ignored; it is kept only so existing callers keep working.
        The file's first ``dill.load`` must yield a bytes payload, which is then
        decoded again with ``dill.loads``.

        SECURITY: dill (like pickle) can execute arbitrary code while loading;
        only use this on files this program wrote itself.
        """
        with open(path, "rb") as fp:
            payload = dill.load(fp)
        return dill.loads(payload)

    def para2pos(self, para, _get_pkl_hash):
        """Convert one row of parameter values into a position array.

        Args:
            para: table-like object indexed as ``para[[key]].values[0][0]``
                (e.g. a one-row pandas DataFrame, presumably).
            _get_pkl_hash: callable mapping a 40-character string to an
                iterable of file paths that may hold the pickled object.

        Returns:
            ``np.ndarray`` of indices into each hyperparameter's value list.

        Raises:
            ValueError: if a value is not among its hyperparameter's candidates.
        """
        pos_list = []
        for pos_key, candidates in self.search_space.items():
            value = para[[pos_key]].values[0][0]

            # 40-char strings are treated as hashes (presumably SHA-1 hex) that
            # stand in for non-string objects pickled to disk; try each path
            # until one yields a non-string object.
            if isinstance(value, str) and len(value) == 40:
                for path in _get_pkl_hash(value):
                    value = self._read_dill(value, path)
                    if not isinstance(value, str):
                        break

            pos_list.append(candidates.index(value))

        return np.array(pos_list)

    def pos2para(self, pos):
        """Convert a position array back into a ``{name: value}`` dict.

        Returns None when ``pos`` does not have exactly one entry per
        hyperparameter.
        """
        if len(self.search_space) != pos.size:
            return None

        return {
            key: list(values)[int(idx)]
            for (key, values), idx in zip(self.search_space.items(), pos)
        }