1 | import numpy as np |
||
2 | from sklearn.datasets import load_iris |
||
3 | from sklearn.neighbors import KNeighborsClassifier |
||
4 | from sklearn.model_selection import cross_val_score |
||
5 | |||
6 | from gradient_free_optimizers import RandomSearchOptimizer |
||
7 | |||
8 | |||
def test_function():
    """Run RandomSearchOptimizer on the 1-D objective ``-x1 * x1``.

    The objective is never positive, so the best score reported after the
    search must be non-positive. It must also equal the objective evaluated
    at the reported best parameters.
    """

    def objective_function(para):
        score = -para["x1"] * para["x1"]
        return score

    search_space = {
        "x1": np.arange(-100, 101, 1),
    }

    opt = RandomSearchOptimizer(search_space)
    opt.search(objective_function, n_iter=30)

    # Previously this test only checked that search() did not raise; also
    # verify the reported optimum is consistent with the objective.
    assert opt.best_score <= 0
    assert opt.best_score == objective_function(opt.best_para)
||
20 | |||
21 | |||
def test_sklearn():
    """Tune KNeighborsClassifier's ``n_neighbors`` on iris with random search.

    The objective is the mean 5-fold cross-validation score from
    ``cross_val_score``. That mean is an accuracy for a classifier, so the
    best score must lie in ``[0, 1]``. The best ``n_neighbors`` must be a
    value drawn from the search space.
    """
    data = load_iris()
    X, y = data.data, data.target

    def model(para):
        knr = KNeighborsClassifier(n_neighbors=para["n_neighbors"])
        scores = cross_val_score(knr, X, y, cv=5)
        return scores.mean()

    search_space = {
        "n_neighbors": np.arange(1, 51, 1),
    }

    opt = RandomSearchOptimizer(search_space)
    opt.search(model, n_iter=30)

    # Previously this test asserted nothing; sanity-check the result.
    assert 0.0 <= opt.best_score <= 1.0
    assert opt.best_para["n_neighbors"] in search_space["n_neighbors"]
||
39 | |||
40 | |||
def test_obj_func_return_dictionary_0():
    """Check that an extra dict returned with the score shows up as a search_data column."""

    def objective_function(para):
        x1 = para["x1"]
        # (score, extra-data) tuple: the dict key is expected as a column.
        return -x1 * x1, {"_x1_": x1}

    opt = RandomSearchOptimizer({"x1": np.arange(-100, 101, 1)})
    opt.search(objective_function, n_iter=30)

    assert "_x1_" in opt.search_data.columns
||
54 | |||
55 | |||
def test_obj_func_return_dictionary_1():
    """Check that every key of the extra dict returned with the score becomes a search_data column."""

    def objective_function(para):
        x1 = para["x1"]
        extras = {"_x1_": x1, "_x1_*2": x1 * 2}
        return -x1 * x1, extras

    opt = RandomSearchOptimizer({"x1": np.arange(-100, 101, 1)})
    opt.search(objective_function, n_iter=30)

    columns = set(opt.search_data.columns)
    for expected in ("_x1_", "_x1_*2"):
        assert expected in columns
||
70 |