Passed
Push — master ( e7a955...cd6747 )
by Simon
04:58
created

SMBO.track_X_sample()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
from ..base_optimizer import BaseOptimizer
7
from ...search import Search
8
from .sampling import InitialSampler
9
10
import numpy as np
11
12
# Tell NumPy to ignore divide-by-zero and invalid-operation floating-point
# errors instead of reporting them.
# NOTE(review): this runs at import time and changes NumPy's error handling
# beyond this module as well -- confirm that is intended.
np.seterr(divide="ignore", invalid="ignore")
13
14
15
class SMBO(BaseOptimizer, Search):
16
    def __init__(
        self,
        *args,
        warm_start_smbo=None,
        max_sample_size=10000000,
        sampling=None,
        **kwargs
    ):
        """Set up the shared state of a sequential model-based optimizer.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to the parent initializer.
        warm_start_smbo : optional
            Previously collected search data. It is only stored here;
            ``init_warm_start_smbo`` reads it (per-parameter columns plus a
            ``"score"`` column).
        max_sample_size : int
            Forwarded to ``InitialSampler``.
        sampling : dict, False or None
            ``False`` disables subsampling of candidate positions. A dict
            with a ``"random"`` key gives the number of candidates drawn at
            random. ``None`` (the default) means ``{"random": 1000000}``.
        """
        super().__init__(*args, **kwargs)
        self.warm_start_smbo = warm_start_smbo

        # Build the default dict per instance: a dict literal in the
        # signature would be one object shared by every optimizer, so a
        # mutation on one instance would leak into all others.
        self.sampling = {"random": 1000000} if sampling is None else sampling

        # NOTE(review): assumes the parent initializer has set ``self.conv``.
        self.sampler = InitialSampler(self.conv, max_sample_size)

        # NOTE: ``memory_warning`` is not invoked here; the ``warnings``
        # option it relies on is currently disabled.

        # Evaluated positions and their scores, kept index-aligned.
        self.X_sample = []
        self.Y_sample = []
37
38
    def init_warm_start_smbo(self):
39
        if self.warm_start_smbo is not None:
40
            # filter out nan and inf
41
            warm_start_smbo = self.warm_start_smbo[
42
                ~self.warm_start_smbo.isin([np.nan, np.inf, -np.inf]).any(1)
43
            ]
44
45
            # filter out elements that are not in search space
46
            int_idx_list = []
47
            for para_name in self.conv.para_names:
48
                search_data_dim = warm_start_smbo[para_name].values
49
                search_space_dim = self.conv.search_space[para_name]
50
51
                int_idx = np.nonzero(np.in1d(search_data_dim, search_space_dim))[0]
52
                int_idx_list.append(int_idx)
53
54
            intersec = int_idx_list[0]
55
            for int_idx in int_idx_list[1:]:
56
                intersec = np.intersect1d(intersec, int_idx)
57
            warm_start_smbo_f = warm_start_smbo.iloc[intersec]
58
59
            X_sample_values = warm_start_smbo_f[self.conv.para_names].values
60
            Y_sample = warm_start_smbo_f["score"].values
61
62
            self.X_sample = self.conv.values2positions(X_sample_values)
63
            self.Y_sample = list(Y_sample)
64
65
    def track_X_sample(iterate):
        """Decorator that records every position produced by ``iterate``.

        The wrapped method is called with the same arguments; its return
        value is appended to ``self.X_sample`` and then returned unchanged.
        """
        from functools import wraps

        # Keep the wrapped method's name and docstring visible for
        # introspection and debugging.
        @wraps(iterate)
        def wrapper(self, *args, **kwargs):
            pos = iterate(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper
72
73
    def track_y_sample(evaluate):
        """Decorator that keeps ``Y_sample`` aligned with ``X_sample``.

        After the wrapped ``evaluate`` has run, a finite score is appended to
        ``self.Y_sample``. A NaN or infinite score is not stored; instead the
        most recently recorded position is removed from ``self.X_sample`` so
        that both lists stay the same length.
        """
        from functools import wraps

        # Keep the wrapped method's name and docstring visible for
        # introspection and debugging.
        @wraps(evaluate)
        def wrapper(self, score):
            evaluate(self, score)

            if np.isfinite(score):
                self.Y_sample.append(score)
            else:
                # Drop the position that produced the unusable score.
                del self.X_sample[-1]

        return wrapper
83
84
    def _sampling(self, all_pos_comb):
85
        if self.sampling is False:
86
            return all_pos_comb
87
        elif "random" in self.sampling:
88
            return self.random_sampling(all_pos_comb)
89
90
    def random_sampling(self, pos_comb):
        """Draw at most ``self.sampling["random"]`` rows from ``pos_comb``.

        When ``pos_comb`` has no more rows than the configured sample count
        it is returned as-is. Otherwise that many distinct rows are picked
        with ``np.random.choice`` (without replacement) and returned.
        """
        sample_budget = self.sampling["random"]
        candidate_count = pos_comb.shape[0]

        # Nothing to thin out: the budget already covers every candidate.
        if candidate_count <= sample_budget:
            return pos_comb

        chosen_rows = np.random.choice(
            candidate_count, sample_budget, replace=False
        )
        return pos_comb[chosen_rows, :]
100
101
    def _all_possible_pos(self):
102
        pos_space = self.sampler.get_pos_space()
103
        # print("pos_space", pos_space)
104
        n_dim = len(pos_space)
105
        return np.array(np.meshgrid(*pos_space)).T.reshape(-1, n_dim)
106
107
    def memory_warning(self, max_sample_size):
108
        if (
109
            self.conv.search_space_size > self.warnings
110
            and max_sample_size > self.warnings
111
        ):
112
            warning_message0 = "\n Warning:"
113
            warning_message1 = (
114
                "\n search space size of "
115
                + str(self.conv.search_space_size)
116
                + " exceeding recommended limit."
117
            )
118
            warning_message3 = "\n Reduce search space size for better performance."
119
            print(warning_message0 + warning_message1 + warning_message3)
120
121
    @track_X_sample
    def init_pos(self, pos):
        """Delegate to the parent ``init_pos`` and return its result.

        The ``track_X_sample`` decorator appends the returned position to
        ``self.X_sample``.
        """
        initial_position = super().init_pos(pos)
        return initial_position
124
125
    @BaseOptimizer.track_nth_iter
    @track_X_sample
    def iterate(self):
        """Return the next position chosen by ``_propose_location``.

        The ``track_X_sample`` decorator appends the returned position to
        ``self.X_sample``.
        """
        proposed_position = self._propose_location()
        return proposed_position
129
130
    @track_y_sample
131
    def evaluate(self, score_new):
132
        self.score_new = score_new
133
134
        self._evaluate_new2current(score_new)
135
        self._evaluate_current2best()
136
137
    def _propose_location(self):
138
        self._training()
139
        exp_imp = self._expected_improvement()
140
141
        index_best = list(exp_imp.argsort()[::-1])
142
        all_pos_comb_sorted = self.pos_comb[index_best]
143
        pos_best = all_pos_comb_sorted[0]
144
145
        return pos_best
146