Passed
Push — master ( eeab31...6c8f7d )
by Simon
02:03
created

SearchSpace.pos2para()   A

Complexity

Conditions 3

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 3
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
9
class SearchSpace:
    """Discrete search space built from the first entry of a search config.

    Each hyperparameter's candidate values are addressed by an integer
    index (a "position"). A point in the space is a numpy array holding one
    position per hyperparameter, in the iteration order of ``self.para_space``.
    """

    def __init__(self, _config_):
        # Only these three attributes are read from the config object here.
        self.search_config = _config_.search_config
        self.warm_start = _config_.warm_start
        self.scatter_init = _config_.scatter_init

        # NOTE(review): not used by the methods in this class; presumably
        # filled in by callers.
        self.memory = {}

    def pos_space_limit(self):
        """Store in ``self.dim`` the largest valid index of every hyperparameter.

        ``self.dim[i]`` is ``len(values_i) - 1``, so the valid positions along
        axis ``i`` are ``0 .. self.dim[i]`` inclusive.
        """
        self.dim = np.array([len(values) - 1 for values in self.para_space.values()])

    def create_kerasSearchSpace(self):
        """Use the first entry of ``self.search_config`` as the parameter space.

        Sets ``self.para_space`` to the value stored under the first key of
        ``self.search_config`` (by iteration order) and recomputes
        ``self.dim``. Any further entries of the search config are ignored.
        """
        first_key = list(self.search_config)[0]
        self.para_space = self.search_config[first_key]

        self.pos_space_limit()

    def get_random_pos(self):
        """Return a uniformly random position in the search space.

        Component ``i`` is drawn uniformly from the integers
        ``0 .. self.dim[i]`` (inclusive), independently of the others. The
        result is a float array (the same dtype the previous ``np.rint``-based
        implementation returned).
        """
        # Rounding a continuous uniform sample gave the two boundary indices
        # only half the probability of interior ones; drawing integers
        # directly makes every index equally likely.
        upper = self.dim.astype(int) + 1  # randint's upper bound is exclusive
        pos = np.random.randint(0, upper, size=self.dim.shape)
        return pos.astype(float)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a random valid index for a single hyperparameter.

        The index is drawn uniformly from ``0 .. n - 1``, where ``n`` is the
        number of candidate values of ``hyperpara_name``.
        """
        n_para_values = len(self.para_space[hyperpara_name])
        return random.randint(0, n_para_values - 1)

    def pos2para(self, pos):
        """Translate a position array into a ``{hyperparameter: value}`` dict.

        ``pos[i]`` is truncated with ``int()`` and used to index the candidate
        values of the i-th hyperparameter in ``self.para_space`` order.

        Returns ``None`` when ``pos.size`` differs from the number of
        hyperparameters.
        """
        keys = list(self.para_space.keys())
        if len(keys) != pos.size:
            # NOTE(review): a mismatched position is reported as None rather
            # than raising; callers need to check for it.
            return None

        return {key: list(self.para_space[key])[int(pos[i])] for i, key in enumerate(keys)}
64