Passed
Push — master ( 9ff666...1a4396 )
by Simon
03:24
created

SearchSpace.para2pos()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 2
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
9
class SearchSpace:
    """Map between hyperparameter values and integer positions in a search space.

    ``self.search_space`` maps each hyperparameter name to the sequence of
    candidate values it can take. A "position" is a numpy array holding, for
    every hyperparameter (in the search space's key order), the index of the
    chosen value within that sequence.
    """

    def __init__(self, _core_, model_nr):
        """Select the search space (and optional init config) of one model.

        Args:
            _core_: object exposing ``search_config`` (a dict of per-model
                search spaces) and ``init_config`` (falsy, or a dict of
                per-model initialization configs). Both dicts are indexed by
                key order via ``model_nr``.
            model_nr: index of the model whose entries are used.

        Sets ``self.init_type`` to ``"warm_start"`` when the init config's
        first key matches the first hyperparameter, to ``"scatter_init"``
        when its first key is ``"scatter_init"``, and to ``None`` otherwise.
        """
        self.search_space = _core_.search_config[list(_core_.search_config)[model_nr]]
        self.pos_space_limit()
        self.init_type = None

        if _core_.init_config:
            self.init_para = _core_.init_config[list(_core_.init_config)[model_nr]]

            # Only the first key decides the strategy. ``next(..., None)``
            # avoids an IndexError when either dict is empty.
            first_init_key = next(iter(self.init_para), None)
            first_space_key = next(iter(self.search_space), None)

            if first_init_key is not None and first_init_key == first_space_key:
                self.init_type = "warm_start"
            elif first_init_key == "scatter_init":
                self.init_type = "scatter_init"

    def pos_space_limit(self):
        """Store in ``self.dim`` the largest valid index of every hyperparameter."""
        self.dim = np.array(
            [len(values) - 1 for values in self.search_space.values()]
        )

    def get_random_pos(self):
        """Return a random position, one integer index per hyperparameter.

        Each index is drawn uniformly from ``0 .. self.dim[i]`` (inclusive).
        """
        # randint's upper bound is exclusive, hence ``+ 1``. Rounding a
        # continuous uniform sample instead would make the first and last
        # value of every hyperparameter half as likely as the others.
        high = np.asarray(self.dim, dtype=int) + 1
        return np.random.randint(0, high, size=self.dim.shape).astype(int)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a uniformly random index into ``search_space[hyperpara_name]``.

        Note: this draws from Python's ``random`` module, not ``np.random``.
        """
        n_para_values = len(self.search_space[hyperpara_name])
        return random.randint(0, n_para_values - 1)

    def para2pos(self, para):
        """Translate hyperparameter values into a position array.

        Args:
            para: a mapping ``{name: value}`` (e.g. the output of
                ``pos2para``), a pandas Series indexed by name, or a pandas
                DataFrame whose first row holds the values.

        Returns:
            numpy array with the index of each value in its search-space
            sequence, in search-space key order.

        Raises:
            ValueError: if a value is not among its hyperparameter's candidates.
        """
        pos_list = []
        for pos_key, values in self.search_space.items():
            value = para[pos_key]
            # A DataFrame column is a Series; use its first (only) entry.
            if hasattr(value, "iloc"):
                value = value.iloc[0]

            try:
                pos = list(values).index(value)
            except ValueError:
                raise ValueError(
                    f"{value!r} is not in the search space of {pos_key!r}"
                ) from None
            pos_list.append(pos)

        return np.array(pos_list)

    def pos2para(self, pos):
        """Translate a position array back into a ``{name: value}`` dict.

        Returns ``None`` when ``pos`` does not have exactly one entry per
        hyperparameter.
        """
        if len(self.search_space) != pos.size:
            return None

        values_dict = {}
        for i, (key, values) in enumerate(self.search_space.items()):
            values_dict[key] = list(values)[int(pos[i])]
        return values_dict
62