Passed
Push — master ( 41483f...5037dc )
by Simon
01:18
created

e()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
9
class SearchSpace:
    """Index-based view of one model's hyperparameter search space.

    A "position" is an integer numpy array with one entry per hyperparameter
    (in the iteration order of ``search_space``); each entry is an index into
    that hyperparameter's list of candidate values.

    Attributes set by ``__init__``:
        search_space: mapping hyperparameter name -> sequence of candidates.
        dim: numpy array holding the largest valid index per hyperparameter.
        init_para: the init configuration selected for this model.
        memory: dict mapping ``position.tobytes()`` -> float score.
        init_type: ``"warm_start"``, ``"scatter_init"`` or ``None``.
    """

    def __init__(self, _core_, model_nr):
        """Select the search space and init config of model number `model_nr`.

        `_core_` must expose ``search_config`` and ``init_config`` mappings;
        the `model_nr`-th key (in iteration order) of each one is used.
        """
        search_key = list(_core_.search_config)[model_nr]
        self.search_space = _core_.search_config[search_key]
        self.pos_space_limit()

        init_key = list(_core_.init_config)[model_nr]
        self.init_para = _core_.init_config[init_key]
        self.memory = {}

        # Only the first key of the init config decides the init type.
        # next(iter(...), None) keeps an empty config from raising IndexError.
        first_init_key = next(iter(self.init_para), None)
        first_space_key = next(iter(self.search_space), None)

        if first_init_key is None:
            self.init_type = None
        elif first_init_key == first_space_key:
            self.init_type = "warm_start"
        elif first_init_key == "scatter_init":
            self.init_type = "scatter_init"
        else:
            self.init_type = None

    def load_memory(self, para, score):
        """Fill ``self.memory`` with already evaluated parameter sets.

        `para` is indexed row-wise through ``.iloc`` and `score` through
        ``.values`` (pandas objects are assumed; confirm against callers).
        Nothing happens when either argument is ``None``.
        """
        if para is None or score is None:
            return

        for idx in range(para.shape[0]):
            pos = self.para2pos(para.iloc[[idx]])
            # tobytes() yields the same bytes as the deprecated tostring(),
            # which was removed in NumPy 2.0.
            pos_key = pos.tobytes()
            # score.values[idx] may be a scalar (Series) or a 1-element
            # array (single-column DataFrame); flatten before converting.
            self.memory[pos_key] = float(np.ravel(score.values[idx])[0])

    def pos_space_limit(self):
        """Store the largest valid index of every hyperparameter in ``dim``."""
        self.dim = np.array(
            [len(self.search_space[key]) - 1 for key in self.search_space]
        )

    def get_random_pos(self):
        """Return a random integer position inside the search space."""
        # NOTE(review): uniform sampling followed by rounding gives the two
        # boundary indices (0 and dim) half the probability of interior
        # ones. Left unchanged so seeded runs stay reproducible.
        pos_new = np.random.uniform(
            np.zeros(self.dim.shape), self.dim, self.dim.shape
        )
        return np.rint(pos_new).astype(int)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a random valid index for the hyperparameter `hyperpara_name`."""
        n_para_values = len(self.search_space[hyperpara_name])
        return random.randint(0, n_para_values - 1)

    def para2pos(self, para):
        """Convert a single parameter set into its position array.

        `para` must support ``para[[name]].values`` (e.g. a one-row pandas
        DataFrame). Raises ``ValueError`` (from ``list.index``) when a value
        is not part of the search space.
        """
        pos_list = []

        for pos_key in self.search_space:
            # Extract the scalar cell value: handing the raw (1, 1) array to
            # list.index relies on array truthiness and breaks for
            # sequence-valued hyperparameters such as tuples.
            value = np.asarray(para[[pos_key]].values).ravel()[0]
            pos_list.append(self.search_space[pos_key].index(value))

        return np.array(pos_list)

    def pos2para(self, pos):
        """Convert a position array into a ``{name: value}`` dict.

        Returns ``None`` when `pos` does not have exactly one entry per
        hyperparameter.
        """
        if len(self.search_space) != pos.size:
            return None

        return {
            key: self.search_space[key][int(idx)]
            for key, idx in zip(self.search_space, pos)
        }
71