Passed
Push — master ( 7dc8c7...0f59c2 )
by Simon
03:25
created

gradient_free_optimizers.optimizers.sequence_model.smbo   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 56
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 37
dl 0
loc 56
rs 10
c 0
b 0
f 0
wmc 7

5 Methods

Rating   Name   Duplication   Size   Complexity  
A SMBO.init_pos() 0 4 1
A SMBO.track_X_sample() 0 7 1
A SMBO._all_possible_pos() 0 7 2
A SMBO.init_warm_start_smbo() 0 6 2
A SMBO.__init__() 0 13 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import numpy as np
7
8
# Silence NumPy's floating-point warnings for division by zero and for invalid
# operations.
# NOTE(review): this runs at import time and is not scoped to this module, so
# it also changes the warning behavior of other NumPy code — confirm that this
# is intended.
np.seterr(divide="ignore", invalid="ignore")
9
10
from ..base_optimizer import BaseOptimizer
11
from ...search import Search
12
13
14
class SMBO(BaseOptimizer, Search):
    """Optimizer base that records sampled positions and their Y values.

    The name presumably stands for sequential model-based optimization
    (the module lives in ``sequence_model``) — confirm against the package
    documentation.

    Attributes set by ``__init__``:
        warm_start_smbo: the value passed to the constructor, unchanged.
        X_sample: list of positions collected so far.
        Y_sample: list of Y values matching ``X_sample``.
        all_pos_comb: 2-D array holding every position of the search grid,
            one row per position (see ``_all_possible_pos``).
    """

    def __init__(
        self,
        search_space,
        initialize=None,
        warm_start_smbo=None,
    ):
        """Set up the optimizer and precompute the grid of positions.

        Args:
            search_space: forwarded unchanged to the parent ``__init__``.
            initialize: forwarded to the parent ``__init__``. When ``None``
                (the default) ``{"grid": 4, "random": 2, "vertices": 4}``
                is used.
            warm_start_smbo: ``None`` or a pair
                ``(X_sample_values, Y_sample)`` that
                ``init_warm_start_smbo`` later unpacks.
        """
        if initialize is None:
            # Build the default per call: a dict literal in the signature
            # would be one object shared by every instance, so a mutation
            # made through one optimizer would leak into all later ones.
            initialize = {"grid": 4, "random": 2, "vertices": 4}

        super().__init__(search_space, initialize)
        self.warm_start_smbo = warm_start_smbo

        self.X_sample = []
        self.Y_sample = []

        # Relies on ``self.conv`` already existing at this point
        # (presumably created by the parent ``__init__`` — TODO confirm).
        self.all_pos_comb = self._all_possible_pos()

    def init_warm_start_smbo(self):
        """Seed the sample history from ``warm_start_smbo`` if it was given.

        ``warm_start_smbo`` is unpacked as ``(X_sample_values, Y_sample)``.
        The X values are converted with ``self.conv.values2positions`` and
        stored as ``X_sample``; ``Y_sample`` is copied into a new list.
        Does nothing when ``warm_start_smbo`` is ``None``.
        """
        if self.warm_start_smbo is None:
            return

        X_sample_values, Y_sample = self.warm_start_smbo

        # NOTE(review): ``track_X_sample`` later calls ``.append`` on
        # ``X_sample``, so ``values2positions`` is assumed to return a
        # list — confirm.
        self.X_sample = self.conv.values2positions(X_sample_values)
        self.Y_sample = list(Y_sample)

    def track_X_sample(func):
        """Decorator that records the position returned by ``func``.

        It is applied while the class body executes, so ``func`` is the
        plain function, not a bound method. The wrapper calls ``func``,
        appends its return value to ``self.X_sample`` and returns that
        value unchanged.
        """
        # Function-scope import keeps the decorator self-contained.
        from functools import wraps

        # ``wraps`` preserves the decorated method's name and docstring.
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            pos = func(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper

    def _all_possible_pos(self):
        """Return every combination of per-dimension positions.

        For each entry ``d`` of ``self.conv.max_positions`` the candidate
        positions are ``0 .. d - 1``.

        Returns:
            A 2-D array with one row per combination and one column per
            dimension. The number of rows is the product of all
            ``max_positions`` entries, so memory use grows quickly with
            the size of the search space.
        """
        pos_space = [np.arange(dim_) for dim_ in self.conv.max_positions]
        n_dim = len(pos_space)

        # meshgrid yields one coordinate array per dimension; transposing
        # and flattening turns them into rows of full positions.
        return np.array(np.meshgrid(*pos_space)).T.reshape(-1, n_dim)

    @track_X_sample
    def init_pos(self, pos):
        """Forward ``pos`` to the parent ``init_pos`` and return it.

        The ``track_X_sample`` decorator appends the returned position to
        ``self.X_sample``.
        """
        super().init_pos(pos)
        return pos
56