Passed
Pull Request — master (#70)
by
unknown
01:25
created

HilbertGridSearchOptimizer.iterate()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
# gradient_free_optimizers/hilbert_grid_search.py
2
# Author: Simon Blanke
3
# Email: [email protected]
4
# License: MIT License
5
6
import numpy as np
7
from numpy_hilbert_curve import decode
8
from ..base_optimizer import BaseOptimizer
9
10
11
class HilbertGridSearchOptimizer(BaseOptimizer):
    """Grid search that walks the search space along a Hilbert curve.

    Hilbert integers ``Z = 0, 1, 2, ...`` are decoded into n-dimensional
    points on a hypercube whose side (``2**num_bits``) is just large enough
    to cover the largest dimension of the search space. Points outside the
    actual grid (``point[i] >= dim_sizes[i]``) are skipped. Of the remaining
    valid points, the 1st, ``(step_size + 1)``-th, ``(2 * step_size + 1)``-th,
    ... are proposed. After every cell of the hypercube has been decoded,
    ``Z`` wraps back to 0 and the walk repeats.

    Parameters
    ----------
    search_space, initialize, constraints, random_state, rand_rest_p,
    nth_process
        Forwarded unchanged to ``BaseOptimizer``.
    step_size : int, default 1
        Propose only every ``step_size``-th valid grid point. Must be >= 1.

    Raises
    ------
    ValueError
        If ``step_size`` is smaller than 1.
    """

    def __init__(
        self,
        search_space,
        initialize={"grid": 4, "random": 2, "vertices": 4},
        constraints=[],
        random_state=None,
        rand_rest_p=0,
        nth_process=None,
        step_size=1,
    ):
        # A step_size < 1 would make hilbert_move divide by zero (0) or never
        # select a point (negative), so reject it before doing any work.
        if step_size < 1:
            raise ValueError(
                "step_size must be >= 1, got {!r}".format(step_size)
            )
        # NOTE(review): the mutable defaults above are kept for signature
        # compatibility; they are passed straight through to BaseOptimizer.
        super().__init__(
            search_space=search_space,
            initialize=initialize,
            constraints=constraints,
            random_state=random_state,
            rand_rest_p=rand_rest_p,
            nth_process=nth_process,
        )
        self.step_size = step_size
        self.Z = 0  # next Hilbert integer to decode
        self.valid_count = 0  # number of in-bounds points encountered so far

    def hilbert_move(self):
        """Return the next selected grid point along the Hilbert curve.

        Returns
        -------
        numpy.ndarray
            Integer grid indices, one per dimension, with
            ``0 <= point[i] < dim_sizes[i]``.
        """
        n_dim = self.conv.n_dim
        dim_sizes = np.asarray(self.conv.dim_sizes)

        # Bits per dimension so that the hypercube side 2**num_bits covers
        # index max(dim_sizes) - 1. The original passed n_dim here, which made
        # the curve's extent unrelated to the grid: large dimensions were
        # never fully visited and a 1-D space only saw indices 0 and 1.
        num_bits = max(1, (int(np.max(dim_sizes)) - 1).bit_length())

        # Total number of Hilbert indices on the hypercube. Z wraps around
        # after the whole curve has been walked, which keeps it bounded.
        n_cells = 1 << (n_dim * num_bits)

        while True:
            point = decode(np.array([self.Z]), n_dim, num_bits)[0].astype(int)
            self.Z = (self.Z + 1) % n_cells

            # The hypercube may be larger than the grid; skip points that
            # fall outside the real search space.
            if not np.all(point < dim_sizes):
                continue

            self.valid_count += 1
            # Select the 1st, (step_size+1)-th, ... valid point. The former
            # test `valid_count % step_size == 1` can never be true when
            # step_size == 1 (the default), which made this loop endless.
            if (self.valid_count - 1) % self.step_size == 0:
                return point

    @BaseOptimizer.track_new_pos
    def iterate(self):
        """Propose the next position: the next selected Hilbert-curve grid
        point, converted via ``conv2pos``."""
        return self.conv2pos(self.hilbert_move())

    @BaseOptimizer.track_new_score
    def evaluate(self, score_new):
        """Record ``score_new`` by delegating to ``BaseOptimizer.evaluate``."""
        BaseOptimizer.evaluate(self, score_new)