Passed
Push — master ( 502a6e...660aa4 )
by Simon
04:50 queued 14s
created

gradient_free_optimizers.optimizers.grid.grid_search   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 43
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 28
dl 0
loc 43
rs 10
c 0
b 0
f 0
wmc 5

3 Methods

Rating   Name   Duplication   Size   Complexity  
A GridSearchOptimizer.iterate() 0 3 1
A GridSearchOptimizer.__init__() 0 17 3
A GridSearchOptimizer.evaluate() 0 3 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
from ..base_optimizer import BaseOptimizer
6
from .diagonal_grid_search import DiagonalGridSearchOptimizer
7
from .orthogonal_grid_search import OrthogonalGridSearchOptimizer
8
9
10
class GridSearchOptimizer(BaseOptimizer):
    """Grid search that delegates traversal to a direction-specific optimizer.

    Depending on ``direction``, an inner ``DiagonalGridSearchOptimizer`` or
    ``OrthogonalGridSearchOptimizer`` is constructed, and ``iterate`` /
    ``evaluate`` are forwarded to it.
    """

    name = "Grid Search"
    _name_ = "grid_search"
    __name__ = "GridSearchOptimizer"

    optimizer_type = "global"
    computationally_expensive = False

    def __init__(self, *args, step_size=1, direction="diagonal", **kwargs):
        """Initialise the optimizer and its direction-specific delegate.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged both to ``BaseOptimizer.__init__`` and to the
            constructor of the inner grid-search optimizer.
        step_size : int, default 1
            Stored on the instance and passed to the inner optimizer.
        direction : {"diagonal", "orthogonal"}, default "diagonal"
            Selects ``DiagonalGridSearchOptimizer`` or
            ``OrthogonalGridSearchOptimizer`` as the delegate.

        Raises
        ------
        ValueError
            If ``direction`` is not ``"diagonal"`` or ``"orthogonal"``.
        """
        super().__init__(*args, **kwargs)

        self.step_size = step_size
        self.direction = direction

        if direction == "orthogonal":
            self.grid_search_opt = OrthogonalGridSearchOptimizer(
                *args, step_size=step_size, **kwargs
            )
        elif direction == "diagonal":
            self.grid_search_opt = DiagonalGridSearchOptimizer(
                *args, step_size=step_size, **kwargs
            )
        else:
            # Previously raised Exception("") with no message; a ValueError
            # (still an Exception subclass) naming the bad value and the
            # accepted options makes misconfiguration diagnosable.
            msg = (
                f"Invalid direction {direction!r} for GridSearchOptimizer; "
                "expected 'diagonal' or 'orthogonal'."
            )
            raise ValueError(msg)

    @BaseOptimizer.track_new_pos
    def iterate(self):
        """Return the next position proposed by the inner grid-search optimizer.

        NOTE(review): the ``track_new_pos`` decorator presumably records the
        returned position on this optimizer; confirm in ``BaseOptimizer``.
        """
        return self.grid_search_opt.iterate()

    @BaseOptimizer.track_new_score
    def evaluate(self, score_new):
        """Forward ``score_new`` to the inner grid-search optimizer.

        NOTE(review): the ``track_new_score`` decorator presumably records the
        score on this optimizer; confirm in ``BaseOptimizer``.
        """
        self.grid_search_opt.evaluate(score_new)
43