Total Complexity | 7 |
Total Lines | 56 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # Author: Simon Blanke |
||
2 | # Email: [email protected] |
||
3 | # License: MIT License |
||
4 | |||
5 | |||
6 | import numpy as np |
||
7 | |||
8 | np.seterr(divide="ignore", invalid="ignore") |
||
9 | |||
10 | from ..base_optimizer import BaseOptimizer |
||
11 | from ...search import Search |
||
12 | |||
13 | |||
class SMBO(BaseOptimizer, Search):
    """Common base for optimizers that keep a history of sampled positions.

    ``SMBO`` presumably stands for sequential model-based optimization
    (confirm against the subclasses). This class itself only:

    * stores an optional warm-start dataset (``warm_start_smbo``),
    * keeps ``X_sample`` (positions) and ``Y_sample`` (scores) lists,
    * precomputes every position of the discrete search space in
      ``all_pos_comb``,
    * provides the ``track_X_sample`` decorator, which records every
      position a decorated method returns.

    ``self.conv`` and the parent ``init_pos`` come from the base classes;
    they are not defined here.
    """

    def __init__(
        self,
        search_space,
        initialize=None,
        warm_start_smbo=None,
    ):
        """Set up the sampling history and the table of all positions.

        Args:
            search_space: Passed unchanged to the base-class constructor.
            initialize: Initialization config passed to the base-class
                constructor. ``None`` (the default) selects
                ``{"grid": 4, "random": 2, "vertices": 4}``.
            warm_start_smbo: Optional ``(X_values, Y)`` pair. It is only
                stored here and is consumed when ``init_warm_start_smbo``
                is called.
        """
        # Build a fresh dict per instance. A dict literal used as the default
        # value would be a single object shared by every instance, so any
        # downstream mutation of it would leak into later optimizers.
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}

        super().__init__(search_space, initialize)
        self.warm_start_smbo = warm_start_smbo

        self.X_sample = []
        self.Y_sample = []

        self.all_pos_comb = self._all_possible_pos()

    def init_warm_start_smbo(self):
        """Seed the sampling history from ``warm_start_smbo``, if one was given.

        ``warm_start_smbo`` is unpacked as ``(X_sample_values, Y_sample)``.
        The X values are converted to positions with
        ``self.conv.values2positions``. Does nothing when ``warm_start_smbo``
        is ``None``.
        """
        if self.warm_start_smbo is None:
            return

        X_sample_values, Y_sample = self.warm_start_smbo

        # Keep X_sample a plain list, like Y_sample. Methods wrapped by
        # track_X_sample call X_sample.append(...), which would fail if
        # values2positions returned an ndarray. Its return type is not
        # visible here (TODO confirm); list() is a harmless copy if it is
        # already a list.
        self.X_sample = list(self.conv.values2positions(X_sample_values))
        self.Y_sample = list(Y_sample)

    def track_X_sample(func):
        """Decorator: append the position returned by ``func`` to ``self.X_sample``.

        It is applied at class-definition time (e.g. on ``init_pos`` below),
        so it receives the undecorated function rather than ``self``.
        """
        import functools

        # wraps() keeps the decorated method's __name__/__doc__ intact
        # instead of exposing the inner "wrapper" function.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            pos = func(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper

    def _all_possible_pos(self):
        """Return every position of the discrete search space as a 2-D array.

        Dimension ``d`` ranges over ``0 .. self.conv.max_positions[d] - 1``.
        The result has one column per dimension and one row per combination,
        i.e. ``prod(max_positions)`` rows, so it can get large for big search
        spaces. Row order follows ``np.meshgrid``'s default ``"xy"`` indexing,
        not lexicographic order.
        """
        pos_space = [np.arange(dim_) for dim_ in self.conv.max_positions]
        n_dim = len(pos_space)
        # np.array(meshgrid) stacks the grids along axis 0. Transposing moves
        # that axis last, so the reshape yields one
        # (pos_0, ..., pos_{n-1}) row per grid point.
        return np.array(np.meshgrid(*pos_space)).T.reshape(-1, n_dim)

    @track_X_sample
    def init_pos(self, pos):
        """Run the base-class ``init_pos`` and return ``pos``.

        The decorator records the returned ``pos`` in ``self.X_sample``.
        """
        super().init_pos(pos)
        return pos
||
56 |