Passed
Push — master ( 7dc8c7...0f59c2 )
by Simon
03:25
created

SMBO.init_warm_start_smbo()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import numpy as np
7
8
# Suppress NumPy floating-point warnings for division by zero and invalid
# operations (e.g. 0/0 producing NaN).
# NOTE(review): this mutates NumPy's error-handling state at import time and is
# not scoped to this module; wrapping the specific computations in
# `np.errstate(...)` would be more contained — confirm nothing relies on the
# global setting before changing it.
np.seterr(invalid="ignore", divide="ignore")
9
10
from ..base_optimizer import BaseOptimizer
11
from ...search import Search
12
13
14
class SMBO(BaseOptimizer, Search):
    """Common base for sequential model-based optimizers.

    Keeps the history of sampled positions in ``X_sample`` and the values
    paired with them in ``Y_sample`` (presumably the scores; this class only
    fills ``Y_sample`` from a warm start — TODO confirm where subclasses
    append to it). Also precomputes every position of the search space in
    ``all_pos_comb``.
    """

    def __init__(
        self,
        search_space,
        initialize=None,
        warm_start_smbo=None,
    ):
        """Set up sample storage and the enumerated position grid.

        Args:
            search_space: forwarded unchanged to the parent constructors.
            initialize: mapping forwarded to the parent constructor; it
                appears to map an initialization method to a count. ``None``
                (the default) means ``{"grid": 4, "random": 2, "vertices": 4}``.
                A fresh dict is built per instance, so no state leaks between
                optimizers through a shared default.
            warm_start_smbo: optional ``(X_sample_values, Y_sample)`` pair
                consumed by ``init_warm_start_smbo``.
        """
        if initialize is None:
            # Built per call: a dict literal in the signature would be a
            # single object shared by every instance.
            initialize = {"grid": 4, "random": 2, "vertices": 4}
        super().__init__(search_space, initialize)
        self.warm_start_smbo = warm_start_smbo

        self.X_sample = []
        self.Y_sample = []

        self.all_pos_comb = self._all_possible_pos()

    def init_warm_start_smbo(self):
        """Seed the sample history from ``warm_start_smbo`` when it is set.

        ``warm_start_smbo`` is unpacked as ``(X_sample_values, Y_sample)``.
        The X values are converted to positions with
        ``self.conv.values2positions``, and both sequences are stored as
        lists. Does nothing when ``warm_start_smbo`` is ``None``.
        """
        if self.warm_start_smbo is None:
            return

        X_sample_values, Y_sample = self.warm_start_smbo
        # list() mirrors the Y_sample handling and guarantees an appendable
        # container for track_X_sample below. Assumes values2positions returns
        # an iterable of positions — TODO confirm against the converter.
        self.X_sample = list(self.conv.values2positions(X_sample_values))
        self.Y_sample = list(Y_sample)

    def track_X_sample(func):
        """Decorate a method so the position it returns is appended to ``X_sample``.

        This is a plain function in the class body, applied as a decorator
        (see ``init_pos``). It is not meant to be called on an instance.
        """
        import functools

        # functools.wraps keeps the wrapped method's __name__/__doc__ intact.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            pos = func(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper

    def _all_possible_pos(self):
        """Return every position combination as a 2-D ``(n_combinations, n_dim)`` array.

        Dimension ``i`` contributes ``np.arange(self.conv.max_positions[i])``.
        The cartesian product of those ranges is built with ``np.meshgrid``
        and flattened to one row per combination.
        """
        # NOTE(review): np.arange(dim_) stops at dim_ - 1. If
        # conv.max_positions holds the largest valid index (as the name
        # suggests) rather than the dimension size, the last position of every
        # dimension is never enumerated — confirm against the converter.
        pos_space = [np.arange(dim_) for dim_ in self.conv.max_positions]
        n_dim = len(pos_space)
        return np.array(np.meshgrid(*pos_space)).T.reshape(-1, n_dim)

    @track_X_sample
    def init_pos(self, pos):
        """Delegate to the parent's ``init_pos`` and return ``pos``.

        The decorator records the returned ``pos`` in ``X_sample``.
        """
        super().init_pos(pos)
        return pos