Passed
Push — master ( 268666...521b2e )
by Simon
01:18
created

SBOM.track_X_sample()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import numpy as np
7
8
np.seterr(divide="ignore", invalid="ignore")
9
10
from ..base_optimizer import BaseOptimizer
11
from ...search import Search
12
13
14
def skip_refit_75(i):
    """Return the refit-skip value for iteration ``i`` (steepest schedule).

    Iterations up to and including 33 yield 1; after that the value is
    ``(i - 33) ** 0.75`` truncated to an int.
    """
    return 1 if i <= 33 else int((i - 33) ** 0.75)
18
19
20
def skip_refit_50(i):
    """Return the refit-skip value for iteration ``i`` (square-root schedule).

    Iterations up to and including 33 yield 1; after that the value is
    ``(i - 33) ** 0.5`` truncated to an int.
    """
    return 1 if i <= 33 else int((i - 33) ** 0.5)
24
25
26
def skip_refit_25(i):
    """Return the refit-skip value for iteration ``i`` (flattest schedule).

    Iterations up to and including 33 yield 1; after that the value is
    ``(i - 33) ** 0.25`` truncated to an int.
    """
    return 1 if i <= 33 else int((i - 33) ** 0.25)
30
31
32
def never_skip_refit(i):
    """Constant schedule: return 1 for every iteration ``i``."""
    del i  # unused; kept so all schedules share the same call signature
    return 1
34
35
36
# Lookup table from the ``skip_retrain`` option name accepted by SBOM
# to the schedule function implementing it.
skip_retrain_ = dict(
    many=skip_refit_75,
    some=skip_refit_50,
    few=skip_refit_25,
    never=never_skip_refit,
)
42
43
44
class SBOM(BaseOptimizer, Search):
45
    def __init__(
        self,
        search_space,
        max_sample_size=1000000,
        warm_start_smbo=None,
        skip_retrain="never",
    ):
        """Initialize the sequential model-based optimizer.

        Args:
            search_space: Passed unchanged to the parent initializers.
            max_sample_size: Cap used by ``_sample_size`` when drawing a
                random subset of candidate positions.
            warm_start_smbo: Optional ``(X_sample, Y_sample)`` pair used to
                seed the sample lists. The data is copied, so the caller's
                containers are never modified.
            skip_retrain: Name of the refit schedule in ``skip_retrain_``
                ("many", "some", "few" or "never").

        Raises:
            KeyError: If ``skip_retrain`` is not a known schedule name.
        """
        super().__init__(search_space)

        self.max_sample_size = max_sample_size
        self.warm_start_smbo = warm_start_smbo
        try:
            self.skip_retrain = skip_retrain_[skip_retrain]
        except KeyError:
            raise KeyError(
                "unknown skip_retrain {!r}; expected one of {}".format(
                    skip_retrain, list(skip_retrain_)
                )
            ) from None

        self.X_sample = []
        self.Y_sample = []

        self._all_possible_pos()

        if self.warm_start_smbo is not None:
            X_warm, Y_warm = self.warm_start_smbo
            # Copy into fresh lists. The sample lists are appended to in
            # place later (see track_X_sample), which would otherwise mutate
            # the caller's warm-start data. Copying also makes ndarray
            # inputs work, since an ndarray has no .append.
            self.X_sample = list(X_warm)
            self.Y_sample = list(Y_warm)
65
66
    def track_X_sample(func):
        """Decorator that records each position returned by ``func``.

        Applied inside the class body (e.g. on ``init_pos``). The wrapped
        method's return value is appended to ``self.X_sample`` and then
        returned unchanged.
        """
        import functools

        # functools.wraps keeps the wrapped method's __name__, __doc__ and
        # __wrapped__, so introspection and tracebacks show the real method.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            pos = func(self, *args, **kwargs)
            self.X_sample.append(pos)
            return pos

        return wrapper
73
74
    def get_random_sample(self):
        """Return a random subset of rows of ``self.all_pos_comb``.

        The number of rows drawn is ``self._sample_size()``, clipped to the
        number of available rows. Rows are drawn uniformly without
        replacement via ``np.random.choice``.
        """
        wanted = self._sample_size()
        n_rows = self.all_pos_comb.shape[0]
        n_draw = min(wanted, n_rows)

        picked = np.random.choice(n_rows, size=(n_draw,), replace=False)
        return self.all_pos_comb[picked]
83
84
    def _sample_size(self):
        """Return ``int(max_sample_size * tanh(all_pos_comb.size / max_sample_size))``.

        The value grows roughly linearly while the position table is small
        and saturates at ``max_sample_size``.

        NOTE: ``.size`` counts array elements (rows * n_dim), not rows.
        ``get_random_sample`` clips the result to the row count.
        """
        cap = self.max_sample_size
        fill_ratio = self.all_pos_comb.size / cap
        return int(cap * np.tanh(fill_ratio))
87
88
    def _all_possible_pos(self):
        """Enumerate every grid position of the search space.

        For each entry ``d`` of ``self.space_dim_size`` the index range
        ``0 .. d - 1`` is taken. ``self.all_pos_comb`` receives a dense array
        with one row per combination of those indices (the full Cartesian
        product, materialized in memory). ``self.n_dim`` receives the number
        of dimensions.
        """
        axes = [np.arange(dim_size) for dim_size in self.space_dim_size]
        self.n_dim = len(axes)

        stacked_grid = np.array(np.meshgrid(*axes))
        self.all_pos_comb = stacked_grid.T.reshape(-1, self.n_dim)
97
98
    @track_X_sample
    def init_pos(self, pos):
        """Delegate ``pos`` to the parent ``init_pos`` and return it.

        The ``track_X_sample`` decorator appends the returned position to
        ``self.X_sample``. The parent's return value is ignored.
        """
        parent_init_pos = super().init_pos
        parent_init_pos(pos)
        return pos
102
103