Total Complexity | 6 |
Total Lines | 49 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # Author: Simon Blanke |
||
2 | # Email: [email protected] |
||
3 | # License: MIT License |
||
4 | |||
5 | |||
6 | import numpy as np |
||
7 | |||
8 | np.seterr(divide="ignore", invalid="ignore") |
||
9 | |||
10 | from ..base_optimizer import BaseOptimizer |
||
11 | from ...search import Search |
||
12 | |||
13 | |||
class SMBO(BaseOptimizer, Search):
    """Optimizer base that records sampled positions and enumerates the grid.

    Instances keep two lists, ``X_sample`` (positions) and ``Y_sample``
    (presumably the matching scores; nothing in this class fills it, so
    confirm against the subclasses), optionally seeded from
    ``warm_start_sbom``. ``all_pos_comb`` holds every position of the
    search grid as rows of a 2-D array.
    """

    def __init__(
        self, search_space, warm_start_sbom=None,
    ):
        """Set up the sample history and precompute all grid positions.

        Parameters
        ----------
        search_space
            Passed unchanged to the parent initializer.
        warm_start_sbom : tuple or None, optional
            ``(X_sample, Y_sample)`` pair used to seed the sample history.
            ``None`` (the default) starts with empty histories.
        """
        super().__init__(search_space)
        self.warm_start_sbom = warm_start_sbom

        self.X_sample = []
        self.Y_sample = []

        if self.warm_start_sbom is not None:
            warm_X, warm_Y = self.warm_start_sbom
            # Copy into fresh lists: ``track_X_sample`` appends to
            # ``self.X_sample``, which must neither mutate the caller's
            # warm-start data nor fail when an ndarray/tuple is passed in.
            self.X_sample = list(warm_X)
            self.Y_sample = list(warm_Y)

        self.all_pos_comb = self._all_possible_pos()

    def track_X_sample(func):
        """Decorate a method so the position it returns is recorded.

        The wrapped method is called unchanged; its return value is
        appended to ``self.X_sample`` and then returned to the caller.
        """
        # Imported here so this class does not depend on a module-level import.
        import functools

        @functools.wraps(func)  # keep the wrapped method's name and docstring
        def wrapper(self, *args, **kwargs):
            pos = func(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper

    def _all_possible_pos(self):
        """Return every position of the search grid.

        Builds one ``np.arange(dim)`` axis per entry of
        ``self.max_positions`` (not set in this class; presumably provided
        by a parent initializer) and combines them into an array of shape
        ``(n_combinations, n_dim)``, one grid position per row.
        """
        pos_space = [np.arange(dim_) for dim_ in self.max_positions]

        n_dim = len(pos_space)
        # meshgrid yields one coordinate array per dimension; transposing and
        # reshaping flattens them into one row per coordinate combination.
        return np.array(np.meshgrid(*pos_space)).T.reshape(-1, n_dim)

    @track_X_sample
    def init_pos(self, pos):
        """Forward ``pos`` to the parent ``init_pos`` and return it.

        Through ``track_X_sample`` the returned position is also appended
        to ``self.X_sample``.
        """
        super().init_pos(pos)
        return pos
||
49 | |||
50 |