Total Complexity | 5 |
Total Lines | 100 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | import copy |
||
2 | import pytest |
||
3 | import numpy as np |
||
4 | import pandas as pd |
||
5 | |||
6 | from hyperactive import Hyperactive |
||
7 | |||
8 | |||
# Shared search space for every test below: a single integer dimension
# "x1" holding the values -100 .. 99 (200 entries, step 1).
search_space = {"x1": list(np.arange(-100, 100))}
||
12 | |||
13 | |||
def test_callback_0():
    """Attributes set by two "before" callbacks are visible to the objective.

    Both callbacks are registered under the "before" key. The objective
    asserts on both attributes, so the search fails unless both callbacks
    have already run when the objective is evaluated.
    """

    def set_stuff1(opt):
        opt.stuff1 = 1

    def set_stuff2(opt):
        opt.stuff2 = 2

    def objective_function(opt):
        assert opt.stuff1 == 1
        assert opt.stuff2 == 2
        return 0

    callbacks = {"before": [set_stuff1, set_stuff2]}

    hyper = Hyperactive()
    hyper.add_search(objective_function, search_space, n_iter=100, callbacks=callbacks)
    hyper.run()
||
35 | |||
36 | |||
def test_callback_1():
    """A "before" callback takes effect ahead of every objective evaluation.

    The "before" callback sets ``stuff1`` to 1 and the "after" callback sets
    it to 2. The objective expects 1 on every call, so the test passes only
    if the "before" value is in place whenever the objective runs.
    """

    def reset_to_one(opt):
        opt.stuff1 = 1

    def bump_to_two(opt):
        opt.stuff1 = 2

    def objective_function(opt):
        assert opt.stuff1 == 1
        return 0

    callbacks = {"before": [reset_to_one], "after": [bump_to_two]}

    hyper = Hyperactive()
    hyper.add_search(objective_function, search_space, n_iter=100, callbacks=callbacks)
    hyper.run()
||
57 | |||
58 | |||
def test_callback_2():
    """A "before" callback can rewrite ``pass_through`` for the objective.

    ``pass_through`` starts as ``{"stuff1": 0}``. The objective expects 1 on
    every call, including the first, so the callback's write must already be
    visible when each evaluation starts.
    """

    def write_pass_through(opt):
        opt.pass_through["stuff1"] = 1

    def objective_function(opt):
        assert opt.pass_through["stuff1"] == 1
        return 0

    initial_pass_through = {"stuff1": 0}

    hyper = Hyperactive()
    hyper.add_search(
        objective_function,
        search_space,
        n_iter=100,
        callbacks={"before": [write_pass_through]},
        pass_through=initial_pass_through,
    )
    hyper.run()
||
77 | |||
78 | |||
def test_callback_3():
    """An "after" callback's ``pass_through`` update is seen from the 2nd call on.

    ``pass_through`` starts as ``{"stuff1": 0}`` and the "after" callback sets
    ``stuff1`` to 1. The objective expects the initial 0 on the first
    evaluation (``nth_iter == 0``) and the updated 1 on every later one.
    """

    def write_pass_through(opt):
        opt.pass_through["stuff1"] = 1

    def objective_function(opt):
        # Before the first evaluation no "after" callback can have fired yet.
        expected = 0 if opt.nth_iter == 0 else 1
        assert opt.pass_through["stuff1"] == expected
        return 0

    initial_pass_through = {"stuff1": 0}

    hyper = Hyperactive()
    hyper.add_search(
        objective_function,
        search_space,
        n_iter=100,
        callbacks={"after": [write_pass_through]},
        pass_through=initial_pass_through,
    )
    hyper.run()
||
100 |