Passed
Push — master ( 9aa23e...87bd68 )
by Simon
01:40
created

SearchSpace.pos2para()   A

Complexity

Conditions 3

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 3
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
9
class SearchSpace:
    """Index-based view of one model's hyperparameter search space.

    A *position* is a vector holding one index per hyperparameter; each
    index selects an entry from that hyperparameter's list of candidate
    values in ``para_space``.
    """

    def __init__(self, _core_, model_nr):
        """Copy the relevant settings from ``_core_``.

        Args:
            _core_: Object exposing ``search_config``, ``warm_start`` and
                ``scatter_init`` attributes.
            model_nr: Index of the ``search_config`` entry this search
                space describes.
        """
        self.search_config = _core_.search_config
        self.warm_start = _core_.warm_start
        self.scatter_init = _core_.scatter_init
        self.model_nr = model_nr

        self.memory = {}

    def pos_space_limit(self):
        """Store the largest valid index of every hyperparameter in ``self.dim``."""
        # Explicit integer dtype so an empty search space still yields an
        # integer array instead of NumPy's default float64.
        self.dim = np.array(
            [len(values) - 1 for values in self.para_space.values()],
            dtype=int,
        )

    def create_searchspace(self):
        """Select this model's parameter space and compute its index limits.

        ``para_space`` becomes the ``search_config`` entry at position
        ``model_nr`` (in the mapping's iteration order); ``dim`` is then
        derived from it via ``pos_space_limit``.
        """
        model_key = list(self.search_config)[self.model_nr]
        self.para_space = self.search_config[model_key]

        self.pos_space_limit()

    def get_random_pos(self):
        """Return a random position with one index per hyperparameter.

        Returns:
            A float array shaped like ``self.dim`` whose entries are whole
            numbers in ``[0, dim]`` (both ends inclusive), each index being
            equally likely.
        """
        # Flooring a draw from [0, dim + 1) gives every index the same
        # probability. Rounding a draw from [0, dim] would give the two
        # boundary indices only half the weight of interior ones.
        low = np.zeros(self.dim.shape)
        draw = np.random.uniform(low, self.dim + 1, self.dim.shape)

        # Guard against the (floating-point) chance of hitting the upper bound.
        return np.minimum(np.floor(draw), self.dim)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a random valid index for the hyperparameter ``hyperpara_name``."""
        n_para_values = len(self.para_space[hyperpara_name])

        return random.randint(0, n_para_values - 1)

    def pos2para(self, pos):
        """Translate a position into a ``{hyperparameter: value}`` dict.

        Args:
            pos: Array-like with one index per hyperparameter, ordered like
                the keys of ``para_space``. Entries are converted with
                ``int()``.

        Returns:
            Dict mapping every hyperparameter name to the value selected by
            the corresponding index.

        Raises:
            ValueError: If ``pos`` does not contain exactly one entry per
                hyperparameter.
        """
        indices = np.ravel(pos)
        n_params = len(self.para_space)

        if indices.size != n_params:
            raise ValueError(
                "position has {} entries but the search space has {} "
                "hyperparameters".format(indices.size, n_params)
            )

        return {
            key: list(self.para_space[key])[int(index)]
            for key, index in zip(self.para_space, indices)
        }
63