Passed
Push — master ( 61a8e6...a7d091 )
by Simon
03:21
created

SearchSpace.para2pos()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 2
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
9
class SearchSpace:
    """Discrete search space over one model's hyperparameters.

    Every hyperparameter's candidate values are addressed by their integer
    index ("position"). A point in the space is a 1-D integer numpy array
    holding one position per hyperparameter, ordered like ``para_space``.
    """

    def __init__(self, _core_, model_nr):
        """Store the configuration needed to build the search space.

        Args:
            _core_: object exposing ``search_config``, ``warm_start`` and
                ``scatter_init`` attributes.
            model_nr: index of the ``search_config`` entry whose
                hyperparameters define this search space.
        """
        self.search_config = _core_.search_config
        # Not read by this class; presumably consumed by callers elsewhere.
        self.warm_start = _core_.warm_start
        self.scatter_init = _core_.scatter_init
        self.model_nr = model_nr

        # Maps a position's raw bytes (``pos.tobytes()``) -> observed score.
        self.memory = {}

    def load_memory(self, para, score):
        """Seed ``self.memory`` with already-evaluated parameter sets.

        Args:
            para: table with one row per evaluated parameter set and one
                column per hyperparameter (a pandas DataFrame is assumed,
                since ``.iloc`` and ``.shape`` are used).
            score: object whose ``.values[idx]`` is the score of row
                ``idx`` of ``para``.
        """
        for idx in range(para.shape[0]):
            # Double brackets keep a one-row DataFrame for para2pos.
            pos = self.para2pos(para.iloc[[idx]])
            # tobytes() is byte-for-byte identical to the deprecated (and
            # since removed) tostring(), so the memory keys are unchanged.
            self.memory[pos.tobytes()] = float(score.values[idx])

    def pos_space_limit(self):
        """Compute ``self.dim``, the largest valid position per hyperparameter."""
        # dtype=int keeps the array integral even when para_space is empty.
        self.dim = np.array(
            [len(candidates) - 1 for candidates in self.para_space.values()],
            dtype=int,
        )

    def create_searchspace(self):
        """Select this model's parameter space and compute its bounds."""
        # The model_nr-th key of search_config (in iteration order) is used.
        model_key = list(self.search_config)[self.model_nr]
        self.para_space = self.search_config[model_key]
        self.pos_space_limit()

    def get_random_pos(self):
        """Return a uniformly random position inside the search space.

        Coordinate ``i`` is drawn uniformly from ``0 .. self.dim[i]``
        (both ends inclusive).
        """
        # Rounding a continuous uniform sample (the former approach) gave
        # the two end positions only half the probability of interior ones.
        return np.random.randint(0, self.dim + 1, size=self.dim.shape)

    def get_random_pos_scalar(self, hyperpara_name):
        """Return a uniformly random position for a single hyperparameter."""
        n_para_values = len(self.para_space[hyperpara_name])
        return random.randint(0, n_para_values - 1)

    def para2pos(self, para):
        """Convert a single parameter set into its position array.

        Args:
            para: one-row pandas DataFrame (or Series) indexed by
                hyperparameter name.

        Returns:
            1-D integer numpy array with, for each hyperparameter in
            ``self.para_space`` order, the index of its value among the
            candidates.

        Raises:
            ValueError: if ``para`` does not hold exactly one value for a
                hyperparameter, or that value is not one of its candidates.
        """
        pos_list = []
        for pos_key, candidates in self.para_space.items():
            # .item() unwraps the single cell so list.index compares plain
            # scalars instead of relying on scalar-vs-ndarray truthiness.
            value = para[[pos_key]].values.item()
            try:
                pos = list(candidates).index(value)
            except ValueError:
                raise ValueError(
                    "{!r} is not a candidate value of {!r}".format(value, pos_key)
                ) from None
            pos_list.append(pos)

        return np.array(pos_list)

    def pos2para(self, pos):
        """Convert a position array back into a parameter dictionary.

        Args:
            pos: numpy array with one integer-valued index per
                hyperparameter, ordered like ``self.para_space``.

        Returns:
            dict mapping each hyperparameter name to its selected value, or
            ``None`` when ``pos.size`` differs from the number of
            hyperparameters.
        """
        if pos.size != len(self.para_space):
            return None

        return {
            key: list(candidates)[int(idx)]
            for (key, candidates), idx in zip(self.para_space.items(), pos)
        }
67