Passed
Pull Request — master (#101)
by Simon
01:28
created

tests.test_api.test_callbacks.test_callback_1()   A

Complexity

Conditions 1

Size

Total Lines 23
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 17
nop 0
dl 0
loc 23
rs 9.55
c 0
b 0
f 0
1
import copy
2
import pytest
3
import numpy as np
4
import pandas as pd
5
6
from hyperactive.optimizers import HillClimbingOptimizer
7
from hyperactive.experiment import BaseExperiment, add_callback
8
from hyperactive.search_config import SearchConfig
9
10
11
# Search space shared by every test in this module: a single dimension
# ``x0`` whose candidate values are the integers produced by
# ``np.arange(-10, 10, 1)``.
search_config = SearchConfig(x0=list(np.arange(-10, 10, 1)))

# Number of iterations passed to ``add_search`` by each test.
n_iter = 20
16
17
18
def test_callback_0():
    """Two ``before`` callbacks must both have run when the objective is evaluated.

    Each callback stores a marker attribute on ``access``; the objective
    function asserts that both markers are present with the expected values.
    """

    class Experiment(BaseExperiment):
        # First marker written on ``access``.
        def callback_1(self, access):
            access.stuff1 = 1

        # Second marker written on ``access``.
        def callback_2(self, access):
            access.stuff2 = 2

        @add_callback(before=[callback_1, callback_2])
        def objective_function(self, access):
            # Both markers must already be set at this point.
            assert (access.stuff1, access.stuff2) == (1, 2)
            return 0

    exp = Experiment()

    optimizer = HillClimbingOptimizer()
    optimizer.add_search(exp, search_config, n_iter=n_iter)
    optimizer.run()
42
43
44
def test_callback_1():
    """Combine one ``before`` and one ``after`` callback on the same objective.

    The ``before`` callback writes ``stuff1 = 1`` and the ``after`` callback
    writes ``stuff1 = 2``; the objective function asserts that it sees ``1``.
    """

    class Experiment(BaseExperiment):
        # Registered as ``before``: value the objective expects to read.
        def callback_1(self, access):
            access.stuff1 = 1

        # Registered as ``after``: writes a different value to the same
        # attribute, which the objective must not observe.
        def callback_2(self, access):
            access.stuff1 = 2

        @add_callback(before=[callback_1], after=[callback_2])
        def objective_function(self, access):
            assert access.stuff1 == 1
            return 0

    exp = Experiment()

    optimizer = HillClimbingOptimizer()
    optimizer.add_search(exp, search_config, n_iter=n_iter)
    optimizer.run()
67
68
69
def test_callback_2():
    """A ``before`` callback can overwrite state stored on the experiment itself.

    ``setup`` initialises ``test_var`` to 5; the callback sets it to 1 and the
    objective function asserts that it reads 1.
    """

    class Experiment(BaseExperiment):
        # Stores the initial value on the instance.
        def setup(self, test_var):
            self.test_var = test_var

        # Registered as ``before``: replaces whatever ``setup`` stored.
        def callback_1(self, access):
            self.test_var = 1

        @add_callback(before=[callback_1])
        def objective_function(self, access):
            assert self.test_var == 1
            return 0

    exp = Experiment()
    exp.setup(5)

    optimizer = HillClimbingOptimizer()
    optimizer.add_search(exp, search_config, n_iter=n_iter)
    optimizer.run()
94
95
96
def test_callback_3():
    """An ``after`` callback's write must only be visible from the next iteration on.

    ``setup`` initialises ``test_var`` to 0 and the ``after`` callback sets it
    to 1. The objective function asserts that it reads 0 when
    ``access.nth_iter`` is 0 and 1 on every other iteration.
    """

    class Experiment(BaseExperiment):
        # Stores the initial value on the instance.
        def setup(self, test_var):
            self.test_var = test_var

        # Registered as ``after``: changes the value seen by later evaluations.
        def callback_1(self, access):
            self.test_var = 1

        @add_callback(after=[callback_1])
        def objective_function(self, access):
            # Iteration 0 still sees the value from ``setup``; all other
            # iterations see the value written by the callback.
            expected = 0 if access.nth_iter == 0 else 1
            assert self.test_var == expected
            return 0

    exp = Experiment()
    exp.setup(0)

    optimizer = HillClimbingOptimizer()
    optimizer.add_search(exp, search_config, n_iter=n_iter)
    optimizer.run()
124