Passed
Push — master ( 3c0042...f6b58e )
by Simon
01:25
created

SMBO.init_pos()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
from ..base_optimizer import BaseOptimizer
7
from ...search import Search
8
9
import numpy as np
10
from itertools import compress
11
12
# Silence numpy's "divide by zero" and "invalid value" (e.g. 0/0 -> nan)
# floating-point warnings. NOTE(review): np.seterr changes numpy's global
# error-handling state, so this import-time call affects every numpy user in
# the process, not just SMBO; presumably meant for surrogate-model arithmetic
# in subclasses -- consider scoping it with `np.errstate(...)` instead.
np.seterr(divide="ignore", invalid="ignore")
13
14
15
class SMBO(BaseOptimizer, Search):
    """Common base for sequential model-based optimizers.

    Maintains ``X_sample`` (positions) and ``Y_sample`` (scores) lists,
    optionally seeds them from a warm-start table, enumerates every position
    combination of the search space, and can subsample those candidates.
    Presumably the concrete surrogate-model logic lives in subclasses.

    Parameters
    ----------
    search_space :
        Forwarded to the base classes. Each value of
        ``self.conv.search_space`` must support ``len()``.
    initialize : dict, optional
        Initialization strategy forwarded to ``BaseOptimizer``.
        Defaults to ``{"grid": 4, "random": 2, "vertices": 4}``; a fresh dict
        is created per instance so the default is never shared.
    warm_start_smbo : optional
        Table indexable by the parameter names and by ``"score"`` whose
        columns expose ``.values`` (presumably a pandas DataFrame -- verify
        against callers). Rows with a NaN score are dropped.
    sampling : dict, optional
        ``{"random": n}`` caps the number of candidate positions returned by
        :meth:`random_sampling`. Defaults to ``{"random": 100000}`` (fresh
        dict per instance).
    warnings : int
        Search-space size above which a warning is printed; a falsy value
        disables the check.
    """

    def __init__(
        self,
        search_space,
        initialize=None,
        warm_start_smbo=None,
        sampling=None,
        warnings=100000000,
    ):
        # Build the defaults per call: literal dict defaults in the signature
        # would be a single object shared (and mutable) across all instances.
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}
        if sampling is None:
            sampling = {"random": 100000}

        super().__init__(search_space, initialize)
        self.warm_start_smbo = warm_start_smbo
        self.sampling = sampling
        self.warnings = warnings

    def init_position_combinations(self):
        """Reset the sample lists and precompute all position combinations.

        Also prints a warning (via :meth:`memory_warning`) before the possibly
        large enumeration when ``self.warnings`` is enabled.
        """
        # Plain Python ints: the product of dimension sizes can exceed int64.
        search_space_size = 1
        for value_ in self.conv.search_space.values():
            search_space_size *= len(value_)

        self.X_sample = []
        self.Y_sample = []

        if self.warnings:
            self.memory_warning(search_space_size)
        self.all_pos_comb = self._all_possible_pos()

    def init_warm_start_smbo(self):
        """Seed ``X_sample``/``Y_sample`` from ``warm_start_smbo`` if given.

        Parameter values are converted to positions via
        ``self.conv.values2positions``; rows whose score is NaN are removed.
        Replaces (does not extend) any existing sample lists.
        """
        if self.warm_start_smbo is not None:
            X_sample_values = self.warm_start_smbo[self.conv.para_names].values
            Y_sample = self.warm_start_smbo["score"].values

            self.X_sample = self.conv.values2positions(X_sample_values)
            self.Y_sample = list(Y_sample)

            # filter out nan scores (and their positions) so the surrogate
            # never sees undefined targets
            mask = ~np.isnan(Y_sample)
            self.X_sample = list(compress(self.X_sample, mask))
            self.Y_sample = list(compress(self.Y_sample, mask))

    def track_X_sample(func):
        """Decorator: append the position returned by ``func`` to ``X_sample``.

        Deliberately a plain function in the class body (not a method) so it
        can decorate methods here and, presumably, in subclasses via
        ``SMBO.track_X_sample``. Do not call it on an instance.
        """
        from functools import wraps

        # wraps keeps the decorated method's __name__, __doc__ and __wrapped__
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            pos = func(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper

    def random_sampling(self):
        """Return candidate positions, subsampled if there are too many.

        If the number of combinations in ``self.all_pos_comb`` is at most
        ``self.sampling["random"]``, the full array is returned. Otherwise,
        that many rows are drawn uniformly at random without replacement.
        """
        n_samples = self.sampling["random"]
        n_pos_comb = self.all_pos_comb.shape[0]

        if n_pos_comb <= n_samples:
            return self.all_pos_comb
        else:
            _idx_sample = np.random.choice(n_pos_comb, n_samples, replace=False)
            pos_comb_sampled = self.all_pos_comb[_idx_sample, :]
            return pos_comb_sampled

    def _all_possible_pos(self):
        """Return a 2-D array with one row per position combination.

        Each row holds one index per dimension, ranging over
        ``0 .. dim_size - 1`` for every size in ``self.conv.dim_sizes``. The
        smallest unsigned integer dtype allowed by ``self.conv.max_dim`` is
        used to keep this (potentially huge) array compact.
        """
        # Thresholds are one below each dtype's max value -- conservative.
        if self.conv.max_dim < 255:
            _dtype = np.uint8
        elif self.conv.max_dim < 65535:
            _dtype = np.uint16
        elif self.conv.max_dim < 4294967295:
            _dtype = np.uint32
        else:
            _dtype = np.uint64

        pos_space = []
        for dim_ in self.conv.dim_sizes:
            pos_space.append(np.arange(dim_, dtype=_dtype))

        n_dim = len(pos_space)
        # meshgrid enumerates the Cartesian product; transposing and
        # reshaping flattens it into (n_combinations, n_dim) rows.
        return np.array(np.meshgrid(*pos_space)).T.reshape(-1, n_dim)

    def memory_warning(self, search_space_size):
        """Print a warning if ``search_space_size`` exceeds ``self.warnings``."""
        if search_space_size > self.warnings:
            print(
                "\n Warning:"
                f"\n search space size of {search_space_size}"
                " exceeding recommended limit."
                "\n Reduce search space size for better performance."
            )

    @track_X_sample
    def init_pos(self, pos):
        """Delegate to the parent ``init_pos`` and return ``pos``.

        The ``track_X_sample`` decorator appends ``pos`` to ``self.X_sample``.
        """
        super().init_pos(pos)
        return pos
105