Passed
Push — master ( b8f359...927840 )
by Simon
03:10
created

_test_long_term_memory   A

Complexity

Total Complexity 15

Size/Duplication

Total Lines 235
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 152
dl 0
loc 235
rs 10
c 0
b 0
f 0
wmc 15

10 Functions

Rating   Name   Duplication   Size   Complexity  
A func3() 0 2 1
A keras_model() 0 2 1
A test_ltm_1() 0 39 1
A func1() 0 2 1
A objective_function() 0 3 1
A test_ltm_0() 0 30 1
A func2() 0 2 1
A compare_obj() 0 7 3
A model() 0 6 1
A compare_0() 0 2 1

3 Methods

Rating   Name   Duplication   Size   Complexity  
A class1_.__init__() 0 2 1
A class2_.__init__() 0 2 1
A class3_.__init__() 0 2 1
1
import os
2
import inspect
3
import pytest
4
5
from sklearn.datasets import load_iris
6
from sklearn.neighbors import KNeighborsClassifier
7
from sklearn.model_selection import cross_val_score
8
9
import numpy as np
10
import pandas as pd
11
12
from hyperactive import Hyperactive, LongTermMemory
13
14
# Iris dataset used by ``model``: ``X`` holds the feature matrix, ``y`` the labels.
data = load_iris()
X = data.data
y = data.target
16
17
18
def func1():
    """Placeholder function used as a search-space value in ``search_space_func``."""


def func2():
    """Placeholder function used as a search-space value in ``search_space_func``."""


def func3():
    """Placeholder function used as a search-space value in ``search_space_func``."""
28
29
30
class class1:
    """Placeholder class used (as the class itself) in ``search_space_class``."""


class class2:
    """Placeholder class used (as the class itself) in ``search_space_class``."""


class class3:
    """Placeholder class used (as the class itself) in ``search_space_class``."""
40
41
42
class class1_:
    """Placeholder class whose instances populate ``search_space_obj``."""

    def __init__(self):
        """Create an instance carrying no attributes."""


class class2_:
    """Placeholder class whose instances populate ``search_space_obj``."""

    def __init__(self):
        """Create an instance carrying no attributes."""


class class3_:
    """Placeholder class whose instances populate ``search_space_obj``."""

    def __init__(self):
        """Create an instance carrying no attributes."""
55
56
57
# Single integer dimension "x1" covering 2..29.
search_space_int0 = dict(x1=list(range(2, 30)))

# Adds a non-negative integer dimension "x2" covering 0..100.
search_space_int1 = dict(x1=list(range(2, 30)), x2=list(range(101)))

# Adds a non-positive integer dimension "x2" covering -100..0.
search_space_int2 = dict(x1=list(range(2, 30)), x2=list(range(-100, 1)))

# "x2" holds numpy floats produced by np.arange.
search_space_float = dict(
    x1=list(range(2, 30)),
    x2=list(np.arange(0, 0.003, 0.001)),
)

# "x2" holds string values.
search_space_str = dict(x1=list(range(2, 30)), x2=["0", "1", "2"])
80
81
# "x2" values are module-level functions.
search_space_func = dict(x1=list(range(2, 30)), x2=[func1, func2, func3])

# "x2" values are classes (not instances).
search_space_class = dict(x1=list(range(2, 30)), x2=[class1, class2, class3])

# "x2" values are instances created once, at import time.
search_space_obj = dict(
    x1=list(range(2, 30)),
    x2=[class1_(), class2_(), class3_()],
)

# "x2" values are lists (unhashable).
search_space_lists = dict(
    x1=list(range(2, 30)),
    x2=[[1, 1, 1], [1, 2, 1], [1, 1, 2]],
)
102
103
104
def objective_function(opt):
    """Return the negated square of ``opt["x1"]``; scores rise as x1 nears zero."""
    x1 = opt["x1"]
    return -(x1 * x1)
107
108
109
def model(para):
    """Score a k-NN classifier on the iris data (``X``, ``y``).

    ``para["x1"]`` is used as ``n_neighbors``; the return value is the mean
    accuracy over a 2-fold ``cross_val_score``.
    """
    classifier = KNeighborsClassifier(n_neighbors=para["x1"])
    return cross_val_score(classifier, X, y, cv=2).mean()
115
116
117
def keras_model(para):
    """Placeholder objective (not used by the parametrized tests); returns None."""
    return None
119
120
121
def compare_0(results1, results2):
    """Assert that two result frames are equal according to ``DataFrame.equals``.

    ``equals`` requires the same shape and elements and treats NaNs in the
    same positions as equal.
    """
    frames_match = results1.equals(results2)
    assert frames_match
123
124
125
def compare_obj(results1, results2):
    """Assert that the ``"x2"`` columns of two result frames match row by row.

    Used for search spaces whose ``"x2"`` values are functions, classes,
    object instances or lists (see ``search_space_para``), where each pair
    of values at the same position must compare equal with ``==``.

    NOTE(review): objects without a custom ``__eq__`` (e.g. instances of
    ``class1_``) compare by identity, so they only match if the reloaded
    frame holds the very same objects. Confirm against how
    ``LongTermMemory`` stores values.

    Args:
        results1: Result frame with an ``"x2"`` column.
        results2: Frame to compare against, also with an ``"x2"`` column.

    Raises:
        AssertionError: If the columns differ in length or any pair of
            values is unequal.
    """
    obj1_list = list(results1["x2"].values)
    # Must read results2: reading results1 here would compare the column
    # with itself, and the check could never fail.
    obj2_list = list(results2["x2"].values)

    # zip() silently stops at the shorter list, so check lengths first.
    assert len(obj1_list) == len(obj2_list), (
        "x2 columns differ in length: {} != {}".format(
            len(obj1_list), len(obj2_list)
        )
    )

    for row, (obj1, obj2) in enumerate(zip(obj1_list, obj2_list)):
        assert obj1 == obj2, "x2 mismatch at row {}: {!r} != {!r}".format(
            row, obj1, obj2
        )
132
133
134
search_space_para = (
    "search_space",
    # Spaces with plain int/float/str values: frames compare with DataFrame.equals.
    [
        (space, compare_0)
        for space in (
            search_space_int0,
            search_space_int1,
            search_space_int2,
            search_space_float,
            search_space_str,
        )
    ]
    # Spaces whose "x2" values are functions, classes, instances or lists.
    + [
        (space, compare_obj)
        for space in (
            search_space_func,
            search_space_class,
            search_space_obj,
            search_space_lists,
        )
    ],
)
148
149
# Values passed as ``path`` to ``LongTermMemory``. NOTE(review): None presumably
# selects LongTermMemory's default location (not visible here).
path_para = ("path", [".", "./", None, "./dir/dir/"])
153
154
155
# Objective functions exercised by each LongTermMemory test.
objective_function_para = ("objective_function", [objective_function, model])
162
163
164
@pytest.mark.parametrize(*objective_function_para)
@pytest.mark.parametrize(*path_para)
@pytest.mark.parametrize(*search_space_para)
def test_ltm_0(objective_function, search_space, path):
    """Round-trip search results through ``LongTermMemory.save``/``load``.

    Runs a short search, saves its results, reloads them, and checks the
    reloaded frame with the comparison function paired with the search
    space in ``search_space_para``.
    """
    (search_space, compare) = search_space

    print("\n objective_function \n", objective_function)
    print("\n search_space \n", search_space)
    print("\n compare \n", compare)
    print("\n path \n", path)

    # The objective function's name doubles as the stored model name.
    model_name = objective_function.__name__

    hyper = Hyperactive()
    hyper.add_search(
        objective_function, search_space, n_iter=10, initialize={"random": 1}
    )
    hyper.run()
    results1 = hyper.results(objective_function)

    memory = LongTermMemory(model_name, path=path)
    try:
        memory.save(results1, objective_function)
        results2 = memory.load()

        print("\n results1 \n", results1)
        print("\n results2 \n", results2)
    finally:
        # Always delete the stored data, even if save/load raised. Cases with
        # the same objective function and path share a model name and path,
        # so leftovers from a failed case could affect later cases.
        memory.remove_model_data()

    compare(results1, results2)
194
195
196
@pytest.mark.parametrize(*objective_function_para)
@pytest.mark.parametrize(*path_para)
@pytest.mark.parametrize(*search_space_para)
def test_ltm_1(objective_function, search_space, path):
    """Run the same search twice, both times with one shared ``LongTermMemory``.

    NOTE(review): ``results1``/``results2`` are only printed and the paired
    ``compare`` function is never called. This test therefore only checks
    that both runs complete without raising. The expected relation between
    the two runs depends on ``LongTermMemory`` semantics not visible here.
    TODO: decide on a real assertion.
    """
    (search_space, compare) = search_space

    print("\n objective_function \n", objective_function)
    print("\n search_space \n", search_space)
    print("\n compare \n", compare)
    print("\n path \n", path)

    # The objective function's name doubles as the stored model name.
    model_name = objective_function.__name__
    memory = LongTermMemory(model_name, path=path)

    try:
        hyper1 = Hyperactive()
        hyper1.add_search(
            objective_function,
            search_space,
            n_iter=10,
            initialize={"random": 1},
            long_term_memory=memory,
        )
        hyper1.run()
        results1 = hyper1.results(objective_function)

        hyper2 = Hyperactive()
        hyper2.add_search(
            objective_function,
            search_space,
            n_iter=10,
            initialize={"random": 1},
            long_term_memory=memory,
        )
        hyper2.run()
        results2 = hyper2.results(objective_function)
    finally:
        # Always delete the stored data, even if a run raised. Cases with the
        # same objective function and path share a model name and path, so
        # leftovers from a failed case could be picked up by later cases.
        memory.remove_model_data()

    print("\n results1 \n", results1)
    print("\n results2 \n", results2)
235