Total Complexity | 5 |
Total Lines | 43 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # Author: Simon Blanke |
||
2 | # Email: [email protected] |
||
3 | # License: MIT License |
||
4 | |||
5 | from ..base_optimizer import BaseOptimizer |
||
6 | from .diagonal_grid_search import DiagonalGridSearchOptimizer |
||
7 | from .orthogonal_grid_search import OrthogonalGridSearchOptimizer |
||
8 | |||
9 | |||
class GridSearchOptimizer(BaseOptimizer):
    """Grid search that delegates to a direction-specific implementation.

    Depending on ``direction``, position proposals (``iterate``) and score
    updates (``evaluate``) are forwarded to either an
    ``OrthogonalGridSearchOptimizer`` or a ``DiagonalGridSearchOptimizer``.
    The delegate is built from the same constructor arguments as this
    optimizer.
    """

    name = "Grid Search"
    _name_ = "grid_search"
    __name__ = "GridSearchOptimizer"

    optimizer_type = "global"
    computationally_expensive = False

    def __init__(self, *args, step_size=1, direction="diagonal", **kwargs):
        """Initialise the optimizer and construct its delegate.

        Args:
            *args: Positional arguments forwarded both to ``BaseOptimizer``
                and to the delegate optimizer.
            step_size: Step size forwarded to the delegate optimizer.
            direction: Which delegate to use; must be ``"diagonal"``
                (default) or ``"orthogonal"``.
            **kwargs: Keyword arguments forwarded both to ``BaseOptimizer``
                and to the delegate optimizer.

        Raises:
            ValueError: If ``direction`` is not ``"diagonal"`` or
                ``"orthogonal"``.
        """
        super().__init__(*args, **kwargs)

        self.step_size = step_size
        self.direction = direction

        if direction == "orthogonal":
            self.grid_search_opt = OrthogonalGridSearchOptimizer(
                *args, step_size=step_size, **kwargs
            )
        elif direction == "diagonal":
            self.grid_search_opt = DiagonalGridSearchOptimizer(
                *args, step_size=step_size, **kwargs
            )
        else:
            # Previously a bare Exception("") was raised, which gave the
            # user no hint about what was wrong. ValueError is still an
            # Exception subclass, so existing broad handlers keep working.
            msg = (
                "direction must be 'diagonal' or 'orthogonal', "
                f"got {direction!r}"
            )
            raise ValueError(msg)

    @BaseOptimizer.track_new_pos
    def iterate(self):
        """Return the next position proposed by the delegate optimizer.

        The base-class ``track_new_pos`` decorator wraps this call.
        """
        return self.grid_search_opt.iterate()

    @BaseOptimizer.track_new_score
    def evaluate(self, score_new):
        """Forward ``score_new`` to the delegate optimizer.

        The base-class ``track_new_score`` decorator wraps this call.
        """
        self.grid_search_opt.evaluate(score_new)
||
43 |