Total Complexity | 5 |
Total Lines | 45 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # Author: Simon Blanke |
||
2 | # Email: [email protected] |
||
3 | # License: MIT License |
||
4 | |||
5 | import numpy as np |
||
6 | |||
7 | from ..base_optimizer import BaseOptimizer |
||
8 | |||
9 | |||
class OrthogonalGridSearchOptimizer(BaseOptimizer):
    """Grid search that steps through the search space with a fixed stride.

    Each trial maps a flat index (``nth_trial * step_size``) onto one grid
    position by decomposing it into per-dimension indices, with the first
    entry of ``conv.dim_sizes`` varying fastest.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``BaseOptimizer.__init__``.
    step_size : int, default 1
        Stride, in flat grid indices, between consecutive trials.
    """

    def __init__(self, *args, step_size=1, **kwargs):
        super().__init__(*args, **kwargs)

        self.step_size = step_size

    def grid_move(self):
        """Return the grid indices for the current trial as a NumPy array.

        The flat index ``nth_trial * step_size`` is shifted by one for every
        full pass it has made over the search space. It is then decomposed
        in mixed radix using ``conv.dim_sizes``: entry ``i`` of the result is
        the index along dimension ``i``. Any remainder beyond the last
        dimension is discarded, so the walk wraps around the search space.

        Assumes ``nth_trial`` and ``step_size`` are non-negative integers
        (``nth_trial`` is presumably maintained by ``BaseOptimizer`` --
        confirm).
        """
        step_index = self.nth_trial * self.step_size
        # Number of completed passes over the search space. Exact integer
        # floor division replaces int(a / b), which goes through a float and
        # loses precision for very large indices. Shifting by one cell per
        # pass presumably lets later passes reach points that earlier passes
        # skipped when step_size > 1 -- confirm intent.
        n_passes = int(step_index // self.conv.search_space_size)
        flat_index = step_index + n_passes

        new_pos = []
        for dim_size in self.conv.dim_sizes:
            # Peel off the least-significant digit for this dimension; the
            # quotient carries over to the remaining dimensions.
            quotient, digit = divmod(flat_index, dim_size)
            new_pos.append(digit)
            # Keep the carry as a plain Python int, as the original did.
            flat_index = int(quotient)

        return np.array(new_pos)

    @BaseOptimizer.track_new_pos
    def iterate(self):
        """Propose the next position: the current grid point passed through
        ``self.conv2pos`` (presumably index -> position conversion; confirm).
        """
        return self.conv2pos(self.grid_move())

    @BaseOptimizer.track_new_score
    def evaluate(self, score_new):
        """Delegate scoring to ``BaseOptimizer.evaluate``.

        NOTE(review): the override appears to exist only so that the
        ``track_new_score`` decorator wraps this class's evaluate -- confirm.
        """
        BaseOptimizer.evaluate(self, score_new)
||
45 |