Passed
Pull Request — master (#101)
by Simon
01:31
created

tests.test_api.test_callbacks.test_callback_0()   A

Complexity

Conditions 1

Size

Total Lines 24
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc (cyclomatic complexity) 1
eloc (code lines) 18
nop 0
dl (duplicated lines) 0
loc (total lines) 24
rs 9.5
c 0
b 0
f 0
1
import copy
2
import pytest
3
import numpy as np
4
import pandas as pd
5
6
from hyperactive.optimizers import HillClimbingOptimizer
7
from hyperactive.experiment import BaseExperiment, add_callback
8
from hyperactive.search_config import SearchConfig
9
10
11
# Shared one-dimensional search space for every test in this module:
# ``x0`` ranges over the integers -10 .. 9 (step 1, np.arange's default),
# kept as NumPy integer scalars inside a plain list.
search_config = SearchConfig(
    x0=list(np.arange(-10, 10)),
)
14
15
16
def test_callback_0():
    """Check that several ``before`` callbacks have populated ``access``.

    ``callback_1`` and ``callback_2`` each write a different attribute onto
    the ``access`` object.  The objective asserts that both values are
    present whenever it is evaluated.
    """

    class Experiment(BaseExperiment):
        def callback_1(self, access):
            access.stuff1 = 1

        def callback_2(self, access):
            access.stuff2 = 2

        # Inside the class body the two callbacks are still plain functions,
        # so they can be handed straight to the decorator.
        @add_callback(before=[callback_1, callback_2])
        def objective_function(self, access):
            assert access.stuff1 == 1
            assert access.stuff2 == 2
            return 0

    exp = Experiment()

    optimizer = HillClimbingOptimizer()
    optimizer.add_search(exp, search_config, n_iter=20)
    optimizer.run()
40
41
42
def test_callback_1():
    """Check that the objective sees the ``before`` value, not the ``after`` one.

    ``callback_1`` (before) sets ``access.stuff1`` to 1 and ``callback_2``
    (after) overwrites it with 2.  The objective asserts ``stuff1 == 1`` on
    every evaluation.
    """

    class Experiment(BaseExperiment):
        def callback_1(self, access):
            access.stuff1 = 1

        def callback_2(self, access):
            access.stuff1 = 2

        @add_callback(before=[callback_1], after=[callback_2])
        def objective_function(self, access):
            # Must be the value written by the ``before`` callback.
            assert access.stuff1 == 1
            return 0

    exp = Experiment()

    optimizer = HillClimbingOptimizer()
    optimizer.add_search(exp, search_config, n_iter=100)
    optimizer.run()
65
66
67
def test_callback_2():
    """Check that a ``before`` callback can overwrite experiment state.

    ``setup(5)`` seeds ``test_var`` with 5.  The ``before`` callback then
    replaces it with 1, and the objective asserts it observes 1.
    """

    class Experiment(BaseExperiment):
        def callback_1(self, access):
            self.test_var = 1

        def setup(self, test_var):
            self.test_var = test_var

        @add_callback(before=[callback_1])
        def objective_function(self, access):
            # The seeded value (5) must already have been replaced.
            assert self.test_var == 1
            return 0

    exp = Experiment()
    exp.setup(5)

    optimizer = HillClimbingOptimizer()
    optimizer.add_search(exp, search_config, n_iter=100)
    optimizer.run()
92
93
94
def test_callback_3():
    """Check that an ``after`` callback's state change is seen on later calls.

    ``setup(0)`` seeds ``test_var`` with 0, so the evaluation with
    ``access.nth_iter == 0`` (presumably the first one -- confirm against the
    optimizer) must still see 0.  ``callback_1`` is registered as an
    ``after`` callback and sets ``test_var`` to 1, so every later evaluation
    must see 1.
    """

    class Experiment(BaseExperiment):
        def callback_1(self, access):
            # Update the attribute the objective actually asserts on.
            # Writing to ``access.pass_through`` (as before) left ``test_var``
            # at 0, so the ``nth_iter > 0`` branch could never pass.
            self.test_var = 1

        def setup(self, test_var):
            self.test_var = test_var

        @add_callback(after=[callback_1])
        def objective_function(self, access):
            if access.nth_iter == 0:
                assert self.test_var == 0
            else:
                assert self.test_var == 1

            return 0

    experiment = Experiment()
    experiment.setup(0)

    hyper = HillClimbingOptimizer()
    hyper.add_search(
        experiment,
        search_config,
        n_iter=100,
    )
    hyper.run()
122