Passed
Push — master ( 7ebd63...8bd9a3 )
by Simon
06:11
created

gradient_free_optimizers.optimizers.grid.grid_search   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 72
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 5
eloc 55
dl 0
loc 72
rs 10
c 0
b 0
f 0

3 Methods

Rating   Name   Duplication   Size   Complexity  
A GridSearchOptimizer.iterate() 0 3 1
B GridSearchOptimizer.__init__() 0 46 3
A GridSearchOptimizer.evaluate() 0 3 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
from ..base_optimizer import BaseOptimizer
6
from .diagonal_grid_search import DiagonalGridSearchOptimizer
7
from .orthogonal_grid_search import OrthogonalGridSearchOptimizer
8
9
10
class GridSearchOptimizer(BaseOptimizer):
    """Grid search that delegates grid traversal to a direction-specific optimizer.

    Depending on ``direction``, the grid is walked by either a
    ``DiagonalGridSearchOptimizer`` or an ``OrthogonalGridSearchOptimizer``.
    ``iterate`` and ``evaluate`` forward to that inner optimizer.
    """

    name = "Grid Search"
    _name_ = "grid_search"
    __name__ = "GridSearchOptimizer"

    optimizer_type = "global"
    computationally_expensive = False

    # Supported ``direction`` values and the optimizer class that walks the
    # grid for each of them.
    _direction_optimizers = {
        "diagonal": DiagonalGridSearchOptimizer,
        "orthogonal": OrthogonalGridSearchOptimizer,
    }

    def __init__(
        self,
        search_space,
        initialize=None,
        constraints=None,
        random_state=None,
        rand_rest_p=0,
        nth_process=None,
        step_size=1,
        direction="diagonal",
    ):
        """Set up the base optimizer and the inner grid-search optimizer.

        Parameters
        ----------
        search_space
            Forwarded unchanged to ``BaseOptimizer`` and to the inner optimizer.
        initialize : dict, optional
            Initialization spec forwarded to both. Defaults to
            ``{"grid": 4, "random": 2, "vertices": 4}``.
        constraints : list, optional
            Constraints forwarded to both. Defaults to an empty list.
        random_state, rand_rest_p, nth_process
            Forwarded unchanged to ``BaseOptimizer`` and to the inner optimizer.
        step_size : int, default 1
            Stored on the instance and forwarded to the inner optimizer.
        direction : str, default "diagonal"
            Either ``"diagonal"`` or ``"orthogonal"``. Selects which inner
            optimizer class walks the grid.

        Raises
        ------
        ValueError
            If ``direction`` is not one of the supported values.
        """
        # Validate before doing any setup work. A tuple membership test uses
        # ``==``, so an unhashable ``direction`` gives a clear ValueError
        # instead of a TypeError from a dict lookup.
        valid_directions = tuple(self._direction_optimizers)
        if direction not in valid_directions:
            raise ValueError(
                f"Invalid direction {direction!r} for {type(self).__name__}; "
                f"expected one of {valid_directions}."
            )

        # Build fresh defaults per instance. Mutable default arguments would be
        # shared by every GridSearchOptimizer ever created.
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}
        if constraints is None:
            constraints = []

        super().__init__(
            search_space=search_space,
            initialize=initialize,
            constraints=constraints,
            random_state=random_state,
            rand_rest_p=rand_rest_p,
            nth_process=nth_process,
        )

        self.step_size = step_size
        self.direction = direction

        optimizer_class = self._direction_optimizers[direction]
        self.grid_search_opt = optimizer_class(
            search_space=search_space,
            initialize=initialize,
            constraints=constraints,
            random_state=random_state,
            rand_rest_p=rand_rest_p,
            nth_process=nth_process,
            step_size=step_size,
        )

    @BaseOptimizer.track_new_pos
    def iterate(self):
        """Return the next position proposed by the inner grid-search optimizer.

        The ``BaseOptimizer.track_new_pos`` decorator wraps this call; it
        presumably records the returned position (verify in BaseOptimizer).
        """
        return self.grid_search_opt.iterate()

    @BaseOptimizer.track_new_score
    def evaluate(self, score_new):
        """Forward ``score_new`` to the inner grid-search optimizer.

        The ``BaseOptimizer.track_new_score`` decorator wraps this call; it
        presumably records the score (verify in BaseOptimizer).
        """
        self.grid_search_opt.evaluate(score_new)
72