Passed
Push — master ( b478a2...16a37b )
by Simon
01:57 queued 10s
created

hyperactive.search_space.SearchSpace.__init__()   A

Complexity

Conditions 4

Size

Total Lines 14
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 11
nop 3
dl 0
loc 14
rs 9.85
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
9
class SearchSpace:
    """Discrete search space of one model's hyperparameters.

    ``search_space`` maps each parameter name to a sequence of candidate
    values. A *position* is an integer array holding one index into each of
    those sequences, in the search space's key order.
    """

    def __init__(self, _core_, model_nr):
        """Select the ``model_nr``-th search space (and init config) from ``_core_``.

        Args:
            _core_: object exposing ``search_config`` (a mapping whose
                ``model_nr``-th value, by key order, is this search space) and
                ``init_config`` (falsy, or a mapping indexed the same way).
            model_nr: positional index into ``search_config`` / ``init_config``.

        Sets ``init_type`` to ``"warm_start"`` when the init dict's first key is
        the first search-space parameter, ``"scatter_init"`` when its first key
        is ``"scatter_init"``, and ``None`` otherwise.
        """
        self.search_space = _core_.search_config[list(_core_.search_config)[model_nr]]
        self.pos_space_limit()
        self.init_type = None

        self.para_names = list(self.search_space.keys())

        if _core_.init_config:
            self.init_para = _core_.init_config[list(_core_.init_config)[model_nr]]

            # An empty init dict (or empty search space) leaves init_type as
            # None instead of raising IndexError on ``[0]``.
            init_keys = list(self.init_para)
            if init_keys:
                first_key = init_keys[0]
                if self.para_names and first_key == self.para_names[0]:
                    self.init_type = "warm_start"
                elif first_key == "scatter_init":
                    self.init_type = "scatter_init"

    def pos_space_limit(self):
        """Store in ``self.dim`` the largest valid index for every parameter."""
        self.dim = np.array([len(values) - 1 for values in self.search_space.values()])

    def get_random_pos(self):
        """Return a random position, each index uniform over ``[0, dim]``.

        ``randint``'s upper bound is exclusive, hence ``+ 1``. Rounding a
        continuous uniform draw instead would give the two end indices only
        half the probability of the interior ones.
        """
        return np.random.randint(0, self.dim.astype(int) + 1, size=self.dim.shape)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a random valid index for the single parameter ``hyperpara_name``.

        Uses Python's ``random`` module (not ``np.random``).
        """
        n_para_values = len(self.search_space[hyperpara_name])
        return random.randint(0, n_para_values - 1)

    def pos2para(self, pos):
        """Translate position array ``pos`` into a ``{param_name: value}`` dict.

        Returns None when ``pos.size`` differs from the number of parameters.
        """
        if len(self.search_space) != pos.size:
            return None

        return {
            key: list(values)[int(idx)]
            for (key, values), idx in zip(self.search_space.items(), pos)
        }
53