Passed
Push — master ( 9aa23e...87bd68 )
by Simon
01:40
created

hyperactive.search_space.search_space   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 63
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 9
eloc 34
dl 0
loc 63
rs 10
c 0
b 0
f 0

6 Methods

Rating   Name   Duplication   Size   Complexity  
A SearchSpace.pos_space_limit() 0 7 2
A SearchSpace.__init__() 0 7 1
A SearchSpace.get_random_pos_scalar() 0 5 1
A SearchSpace.create_searchspace() 0 16 1
A SearchSpace.get_random_pos() 0 5 1
A SearchSpace.pos2para() 0 8 3
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
9
class SearchSpace:
    """Discrete search space for one model entry of the search config.

    A position is an array with one index per hyperparameter. Each index
    points into that parameter's list of candidate values stored in
    ``self.para_space``.
    """

    def __init__(self, _core_, model_nr):
        """Copy the relevant settings from ``_core_``.

        Parameters
        ----------
        _core_ : object
            Object exposing ``search_config``, ``warm_start`` and
            ``scatter_init`` attributes; they are stored unchanged.
        model_nr : int
            Index into the keys of ``search_config``. It selects which
            model's parameter dictionary ``create_searchspace`` uses.
        """
        self.search_config = _core_.search_config
        self.warm_start = _core_.warm_start
        self.scatter_init = _core_.scatter_init
        self.model_nr = model_nr

        # Initialised empty and not used by this class itself.
        # NOTE(review): presumably a cache of evaluated positions that
        # callers fill in — confirm.
        self.memory = {}

    def pos_space_limit(self):
        """Set ``self.dim`` to the largest valid index of each parameter.

        ``self.dim[i]`` is ``len(values) - 1`` for the i-th parameter of
        ``self.para_space``, in dict iteration order.
        """
        self.dim = np.array(
            [len(values) - 1 for values in self.para_space.values()]
        )

    def create_searchspace(self):
        """Select this model's parameter dictionary and compute its bounds.

        ``self.para_space`` becomes the value stored under the
        ``model_nr``-th key of ``self.search_config`` (insertion order).
        ``self.dim`` is then refreshed via ``pos_space_limit``.

        Raises
        ------
        IndexError
            If ``model_nr`` is out of range for ``search_config``.
        """
        model_key = list(self.search_config)[self.model_nr]
        self.para_space = self.search_config[model_key]
        self.pos_space_limit()

    def get_random_pos(self):
        """Draw a uniformly random position from the search space.

        Each coordinate is an independent integer in ``[0, self.dim[i]]``,
        and every index is equally likely.

        Returns
        -------
        numpy.ndarray
            Float array of shape ``self.dim.shape`` holding whole-number
            values.

        Raises
        ------
        ValueError
            If a parameter has no candidate values (``dim`` entry of -1).
        """
        # Rounding a continuous uniform sample gives the first and last
        # index only half the probability of the interior ones. Sampling
        # integers directly is unbiased. ``high`` is exclusive, hence +1.
        high = np.asarray(self.dim, dtype=int) + 1
        pos = np.random.randint(0, high, size=self.dim.shape)
        # Keep the float dtype that callers have always received.
        return pos.astype(float)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a uniformly random valid index for one hyperparameter.

        This draws from the stdlib ``random`` module, so it follows
        ``random.seed`` rather than ``numpy.random.seed``.

        Parameters
        ----------
        hyperpara_name : hashable
            Key of the parameter in ``self.para_space``.

        Returns
        -------
        int
            An index in ``[0, len(self.para_space[hyperpara_name]) - 1]``.
        """
        n_para_values = len(self.para_space[hyperpara_name])
        return random.randint(0, n_para_values - 1)

    def pos2para(self, pos):
        """Translate a position array into a ``{parameter: value}`` dict.

        Parameters
        ----------
        pos : numpy.ndarray
            One index per parameter, in ``self.para_space`` key order.
            Entries are converted with ``int`` (floats are truncated).

        Returns
        -------
        dict or None
            The selected value for every parameter, or ``None`` when
            ``pos.size`` differs from the number of parameters.
        """
        if pos.size != len(self.para_space):
            # NOTE(review): callers may rely on this silent ``None``;
            # raising ValueError would surface mismatches earlier.
            return None

        # ``list(...)`` lets candidate collections be any sequence-like
        # iterable (tuple, range, array), not only lists.
        return {
            key: list(values)[int(pos[i])]
            for i, (key, values) in enumerate(self.para_space.items())
        }
63