Passed
Push — master ( 9fd0ae...3f3c18 )
by Simon
02:02 queued 12s
created

gradient_free_optimizers.optimizers.sequence_model.smbo   A

Complexity

Total Complexity 16

Size/Duplication

Total Lines 102
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 72
dl 0
loc 102
rs 10
c 0
b 0
f 0
wmc 16

9 Methods

Rating   Name   Duplication   Size   Complexity  
A SMBO.random_sampling() 0 10 2
A SMBO.init_pos() 0 4 1
A SMBO.track_X_sample() 0 7 1
A SMBO._all_possible_pos() 0 5 1
A SMBO._sampling() 0 5 3
A SMBO.init_warm_start_smbo() 0 12 2
A SMBO.init_position_combinations() 0 3 1
A SMBO.__init__() 0 18 2
A SMBO.memory_warning() 0 13 3
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
from ..base_optimizer import BaseOptimizer
7
from ...search import Search
8
from .sampling import InitialSampler
9
10
import numpy as np
11
from itertools import compress
12
13
# Silence numpy's "divide by zero" and "invalid value" (e.g. 0/0) floating-point
# warnings. This is an import-time side effect on numpy's error state and is not
# scoped to this module. NOTE(review): presumably it exists to quiet arithmetic
# in the model-based optimizers built on this module — confirm.
np.seterr(invalid="ignore", divide="ignore")
14
15
16
class SMBO(BaseOptimizer, Search):
    """Sequential model-based optimization (SMBO) base class.

    Keeps the history of evaluated positions (``X_sample``) and their scores
    (``Y_sample``). It can seed that history from warm-start data, and it
    provides helpers to enumerate and sub-sample candidate positions of the
    search space.
    """

    def __init__(
        self,
        search_space,
        initialize={"grid": 4, "random": 2, "vertices": 4},
        warm_start_smbo=None,
        init_sample_size=10000000,
        sampling={"random": 1000000},
        warnings=100000000,
    ):
        """Set up the optimizer.

        Parameters
        ----------
        search_space
            Passed unchanged to the base-class constructor.
        initialize : dict
            Initialization strategy, passed unchanged to the base class.
        warm_start_smbo : pandas.DataFrame or None
            Previously evaluated points, with one column per name in
            ``self.conv.para_names`` plus a ``"score"`` column. Loaded by
            ``init_warm_start_smbo``.
        init_sample_size : int
            Forwarded to ``InitialSampler``. ``memory_warning`` also compares
            it against ``warnings``.
        sampling : dict or False
            Candidate sub-sampling strategy used by ``_sampling``.
            ``{"random": n}`` keeps at most ``n`` candidates. A falsy value
            disables sub-sampling.
        warnings : int
            Size threshold used by ``memory_warning``. A falsy value disables
            the check.
        """
        # NOTE(review): the dict defaults in the signature are shared between
        # instances. This class only reads them; confirm the base class never
        # mutates ``initialize`` in place.
        super().__init__(search_space, initialize)
        self.warm_start_smbo = warm_start_smbo
        self.sampling = sampling
        self.warnings = warnings

        self.sampler = InitialSampler(self.conv, init_sample_size)

        if self.warnings:
            self.memory_warning(init_sample_size)

    def init_position_combinations(self):
        """Reset the evaluated-position history and the matching scores."""
        self.X_sample = []
        self.Y_sample = []

    def init_warm_start_smbo(self):
        """Seed ``X_sample``/``Y_sample`` from ``self.warm_start_smbo``.

        Rows that contain NaN, +inf or -inf in any column are discarded. The
        remaining parameter values are converted to positions with
        ``self.conv.values2positions``. Does nothing when no warm-start data
        was given.
        """
        if self.warm_start_smbo is None:
            return

        # One boolean per row: True if any column holds NaN/inf.
        # Pass the axis by keyword; the positional form is deprecated in pandas.
        invalid_rows = self.warm_start_smbo.isin(
            [np.nan, np.inf, -np.inf]
        ).any(axis=1)
        warm_start_smbo = self.warm_start_smbo[~invalid_rows]

        X_sample_values = warm_start_smbo[self.conv.para_names].values
        Y_sample = warm_start_smbo["score"].values

        self.X_sample = self.conv.values2positions(X_sample_values)
        self.Y_sample = list(Y_sample)

    def track_X_sample(func):
        """Decorator that appends the position returned by ``func`` to ``X_sample``.

        It is applied at class-definition time: it receives the wrapped
        function, not ``self``, so it is not meant to be called on an instance.
        """
        from functools import wraps

        # Keep the wrapped method's name and docstring.
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            pos = func(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper

    def _sampling(self, all_pos_comb):
        """Return the candidate positions, optionally sub-sampled.

        If ``self.sampling`` is falsy (e.g. ``False``), ``all_pos_comb`` is
        returned as-is. If it has a ``"random"`` key, the candidates are
        randomly sub-sampled by ``random_sampling``.

        Raises
        ------
        ValueError
            If ``self.sampling`` is truthy but names no supported strategy.
        """
        if not self.sampling:
            return all_pos_comb
        if "random" in self.sampling:
            return self.random_sampling(all_pos_comb)
        raise ValueError(
            "Unsupported sampling strategy "
            + repr(self.sampling)
            + '; expected False or a dict with a "random" key.'
        )

    def random_sampling(self, pos_comb):
        """Randomly keep at most ``self.sampling["random"]`` rows of ``pos_comb``.

        Rows are drawn without replacement from numpy's global random state.
        If ``pos_comb`` has no more rows than requested, it is returned
        unchanged.
        """
        n_samples = self.sampling["random"]
        n_pos_comb = pos_comb.shape[0]

        if n_pos_comb <= n_samples:
            return pos_comb

        idx_sample = np.random.choice(n_pos_comb, n_samples, replace=False)
        return pos_comb[idx_sample]

    def _all_possible_pos(self):
        """Return every combination of the sampler's per-dimension positions.

        The result is a 2-D array with one row per combination. Column ``k``
        holds values taken from ``pos_space[k]``.
        """
        pos_space = self.sampler.get_pos_space()
        n_dim = len(pos_space)
        # The stacked meshgrid has the dimension axis first. ``.T`` moves it
        # last, and reshape flattens the rest into one row per combination.
        return np.array(np.meshgrid(*pos_space)).T.reshape(-1, n_dim)

    def memory_warning(self, init_sample_size):
        """Print a warning when the search space exceeds ``self.warnings``.

        The warning is printed only if both the search-space size and
        ``init_sample_size`` are above the threshold.
        """
        if (
            self.conv.search_space_size > self.warnings
            and init_sample_size > self.warnings
        ):
            message = (
                "\n Warning:"
                + "\n search space size of "
                + str(self.conv.search_space_size)
                + " exceeding recommended limit."
                + "\n Reduce search space size for better performance."
            )
            print(message)

    @track_X_sample
    def init_pos(self, pos):
        """Run the base-class ``init_pos`` for ``pos`` and return ``pos``.

        The ``track_X_sample`` decorator also appends ``pos`` to ``X_sample``.
        """
        super().init_pos(pos)
        return pos
102