Passed
Push — master ( 916d9f...163de3 )
by Simon
02:25
created

hyperactive.search_space.search_space.SearchSpace.para2pos()   B

Complexity

Conditions 6

Size

Total Lines 20
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 20
rs 8.6666
c 0
b 0
f 0
cc 6
nop 3
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import dill
6
import random
7
import numpy as np
8
9
10
class SearchSpace:
11
    def __init__(self, _core_, model_nr):
12
        self.search_space = _core_.search_config[list(_core_.search_config)[model_nr]]
13
        self.pos_space_limit()
14
        self.init_type = None
15
16
        self.para_names = list(self.search_space.keys())
17
18
        if _core_.init_config:
19
            self.init_para = _core_.init_config[list(_core_.init_config)[model_nr]]
20
21
            if list(self.init_para.keys())[0] == list(self.search_space.keys())[0]:
22
                self.init_type = "warm_start"
23
            elif list(self.init_para.keys())[0] == "scatter_init":
24
                self.init_type = "scatter_init"
25
26
    def pos_space_limit(self):
27
        dim = []
28
29
        for pos_key in self.search_space:
30
            dim.append(len(self.search_space[pos_key]) - 1)
31
32
        self.dim = np.array(dim)
33
34
    def get_random_pos(self):
35
        pos_new = np.random.uniform(np.zeros(self.dim.shape), self.dim, self.dim.shape)
36
        pos = np.rint(pos_new).astype(int)
37
38
        return pos
39
40
    def get_random_pos_scalar(self, hyperpara_name):
41
        n_para_values = len(self.search_space[hyperpara_name])
42
        pos = random.randint(0, n_para_values - 1)
43
44
        return pos
45
46
    def pos2para(self, pos):
47
        if len(self.search_space.keys()) == pos.size:
48
            values_dict = {}
49
            for i, key in enumerate(self.search_space.keys()):
50
                pos_ = int(pos[i])
51
                values_dict[key] = list(self.search_space[key])[pos_]
52
53
            return values_dict
54