Passed
Push — master ( 75a579...d9fcea )
by Simon
01:55 queued 15s
created

tests.test_hyper_gradient_trafo.test_trafo_1()   A

Complexity

Conditions 1

Size

Total Lines 32
Code Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 27
nop 3
dl 0
loc 32
rs 9.232
c 0
b 0
f 0
1
import time
2
import pytest
3
import numpy as np
4
import pandas as pd
5
6
from hyperactive import Hyperactive
7
8
9
def objective_function_0(opt):
    """Return the negated square of ``opt["x1"]``; the maximum (0) is at ``x1 == 0``."""
    x1 = opt["x1"]
    return -(x1 * x1)
12
13
14
# Integer search spaces for ``x1``: the full range, the non-negative half and
# the non-positive half.
search_space_0 = {"x1": list(np.arange(-5, 6, 1))}
search_space_1 = {"x1": list(np.arange(0, 6, 1))}
search_space_2 = {"x1": list(np.arange(-5, 1, 1))}

# Float search spaces with step 0.1 (np.arange excludes the stop value, up to
# floating-point rounding).
search_space_3 = {"x1": list(np.arange(-1, 1, 0.1))}
search_space_4 = {"x1": list(np.arange(-1, 0, 0.1))}
search_space_5 = {"x1": list(np.arange(0, 1, 0.1))}

# All of the above; ``test_trafo_0`` is parametrized over this list.
search_space_para_0 = [
    search_space_0,
    search_space_1,
    search_space_2,
    search_space_3,
    search_space_4,
    search_space_5,
]
44
45
46
@pytest.mark.parametrize("search_space", search_space_para_0)
def test_trafo_0(search_space):
    """Check that every evaluated ``x1`` lies inside the given search space.

    Runs a 25-iteration search (no optimizer specified) on
    ``objective_function_0`` and inspects the collected search data.
    """
    hyper = Hyperactive()
    hyper.add_search(objective_function_0, search_space, n_iter=25)
    hyper.run()

    values = hyper.search_data(objective_function_0)["x1"].values
    # Without this guard an empty result would make the loop below pass vacuously.
    assert len(values) > 0, "search produced no data"

    allowed = search_space["x1"]
    for value in values:
        assert value in allowed, f"x1={value!r} is outside the search space"
55
56
57
# ----------------- # Data for testing that memory warm starts speed up the search
58
59
60
from sklearn.datasets import load_breast_cancer
61
from sklearn.model_selection import cross_val_score
62
from sklearn.tree import DecisionTreeClassifier
63
64
# Breast-cancer dataset: ``X`` and ``y`` are read by ``objective_function_1``.
data = load_breast_cancer()
X = data.data
y = data.target
66
67
68
def objective_function_1(opt):
    """Return the mean 10-fold cross-validation score of a decision tree.

    ``opt["min_samples_split"]`` is forwarded to ``DecisionTreeClassifier``.
    The score uses the estimator's default scorer via ``cross_val_score``.
    After scoring, the function sleeps for 0.1 s so that every evaluation
    has a noticeable cost. Presumably this lets the warm-start timing tests
    detect skipped evaluations.
    """
    classifier = DecisionTreeClassifier(min_samples_split=opt["min_samples_split"])
    mean_score = cross_val_score(classifier, X, y, cv=10).mean()
    time.sleep(0.1)
    return mean_score
74
75
76
# Three disjoint ranges of ``min_samples_split`` values.
search_space_0 = {"min_samples_split": list(np.arange(2, 12))}
search_space_1 = {"min_samples_split": list(np.arange(12, 22))}
search_space_2 = {"min_samples_split": list(np.arange(22, 32))}

# Warm-start memories whose positions match the corresponding search space.
# The "score" column holds placeholder numbers (equal to the parameter value),
# not real cross-validation scores.
memory_dict = {"min_samples_split": range(2, 12), "score": range(2, 12)}
memory_warm_start_0 = pd.DataFrame(memory_dict)

memory_dict = {"min_samples_split": range(12, 22), "score": range(12, 22)}
memory_warm_start_1 = pd.DataFrame(memory_dict)

memory_dict = {"min_samples_split": range(22, 32), "score": range(22, 32)}
memory_warm_start_2 = pd.DataFrame(memory_dict)

# (search space, matching warm-start memory) pairs.
search_space_para_1 = [
    (search_space_0, memory_warm_start_0),
    (search_space_1, memory_warm_start_1),
    (search_space_2, memory_warm_start_2),
]

# Random seeds 0 through 4.
random_state_para_0 = list(range(5))
110
111
# ----------------- # Tests for memory warm starts
# Both tests below were previously disabled by wrapping them in a string
# literal. They are now real tests: collected, parsed and linted, but skipped
# until a future gfo version supports them.
_FUTURE_GFO_REASON = "test is possible in future gfo versions"


def _timed_search(search_space, random_state, n_iter, memory_warm_start=None):
    """Run one ``objective_function_1`` search and return its wall time in seconds.

    ``memory_warm_start`` is forwarded to ``add_search`` only when it is given.
    """
    search_kwargs = {}
    if memory_warm_start is not None:
        search_kwargs["memory_warm_start"] = memory_warm_start

    start = time.perf_counter()
    hyper = Hyperactive()
    hyper.add_search(
        objective_function_1,
        search_space,
        n_iter=n_iter,
        random_state=random_state,
        initialize={"random": 1},
        **search_kwargs,
    )
    hyper.run()
    return time.perf_counter() - start


@pytest.mark.skip(reason=_FUTURE_GFO_REASON)
@pytest.mark.parametrize("random_state", random_state_para_0)
@pytest.mark.parametrize("search_space, memory_warm_start", search_space_para_1)
def test_trafo_1(random_state, search_space, memory_warm_start):
    """A warm start covering the search space must make the run over 2x faster."""
    d_time_0 = _timed_search(search_space, random_state, n_iter=10)
    d_time_1 = _timed_search(
        search_space, random_state, n_iter=10, memory_warm_start=memory_warm_start
    )

    assert d_time_1 < d_time_0 * 0.5


# Warm-start memories whose positions come from a *different* search space,
# so none of the remembered points can be reused by the search.
wrong_memory_warm_start_0 = pd.DataFrame(
    {"min_samples_split": range(12, 22), "score": range(2, 12)}
)
wrong_memory_warm_start_1 = pd.DataFrame(
    {"min_samples_split": range(22, 32), "score": range(12, 22)}
)
wrong_memory_warm_start_2 = pd.DataFrame(
    {"min_samples_split": range(2, 12), "score": range(22, 32)}
)

search_space_para_2 = [
    (search_space_0, wrong_memory_warm_start_0),
    (search_space_1, wrong_memory_warm_start_1),
    (search_space_2, wrong_memory_warm_start_2),
]


@pytest.mark.skip(reason=_FUTURE_GFO_REASON)
@pytest.mark.parametrize("random_state", random_state_para_0)
@pytest.mark.parametrize("search_space, memory_warm_start", search_space_para_2)
def test_trafo_2(random_state, search_space, memory_warm_start):
    """A warm start with unusable positions must not noticeably speed up the run."""
    d_time_0 = _timed_search(search_space, random_state, n_iter=25)
    d_time_1 = _timed_search(
        search_space, random_state, n_iter=25, memory_warm_start=memory_warm_start
    )

    assert d_time_1 >= d_time_0 * 0.8
217