Completed
Push — master ( 77ef8d...d45822 )
by Simon
02:00
created

tests.test_memory_helpers.test_connect_model_IDs()   A

Complexity

Conditions 1

Size

Total Lines 17
Code Lines 13

Duplication

Lines 17
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 13
nop 0
dl 17
loc 17
rs 9.75
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
7
from sklearn.datasets import load_iris
8
from sklearn.model_selection import cross_val_score
9
from sklearn.tree import DecisionTreeClassifier
10
from hyperactive import Hyperactive
11
from hyperactive.memory import (
12
    delete_model,
13
    delete_model_dataset,
14
    connect_model_IDs,
15
    split_model_IDs,
16
    get_best_model,
17
    reset_memory,
18
)
19
20
# The Iris dataset is loaded once at import time and shared by every test.
data = load_iris()
X = data.data
y = data.target
22
23
24
def model(para, X_train, y_train):
    """Objective: decision tree whose split criterion is ``para["criterion"]``.

    Returns the mean of a 2-fold ``cross_val_score`` on ``X_train``/``y_train``.
    """
    # Named ``estimator`` so it does not shadow this function's own name.
    estimator = DecisionTreeClassifier(criterion=para["criterion"])
    return cross_val_score(estimator, X_train, y_train, cv=2).mean()
29
30
31
def model1(para, X_train, y_train):
    """Objective: decision tree limited to depth ``para["max_depth"]``.

    Returns the mean of a 2-fold ``cross_val_score`` on ``X_train``/``y_train``.
    The body matches ``model2``; the two are separate functions because the
    connect/split memory tests pair them.
    """
    estimator = DecisionTreeClassifier(max_depth=para["max_depth"])
    return cross_val_score(estimator, X_train, y_train, cv=2).mean()
36
37
38
def model2(para, X_train, y_train):
    """Objective: decision tree limited to depth ``para["max_depth"]``.

    Returns the mean of a 2-fold ``cross_val_score`` on ``X_train``/``y_train``.
    The body matches ``model1``; the two are separate functions because the
    connect/split memory tests pair them.
    """
    estimator = DecisionTreeClassifier(max_depth=para["max_depth"])
    return cross_val_score(estimator, X_train, y_train, cv=2).mean()
43
44
45
# Search spaces passed to ``Hyperactive.search``: each maps an objective
# function to the candidate values of its hyperparameters.
search_config = {
    model: {"criterion": ["gini"]},
}
# Both max_depth grids are range(2, 500), i.e. depths 2 through 499.
search_config1 = {
    model1: {"max_depth": range(2, 500)},
}
search_config2 = {
    model2: {"max_depth": range(2, 500)},
}
48
49
50
def test_reset_memory():
    """Smoke-test ``reset_memory`` with ``force_true=True``.

    NOTE(review): ``force_true`` presumably bypasses an interactive
    confirmation — confirm against hyperactive.memory.
    """
    reset_memory(force_true=True)
52
53
54
def test_delete_model():
    """Call ``delete_model`` on ``model`` both before and after a search.

    The first call may run when nothing is stored for ``model``; the second
    follows a search on ``search_config``.
    """
    delete_model(model)

    optimizer = Hyperactive(X, y)
    optimizer.search(search_config)

    delete_model(model)
61
62
63
def test_delete_model_dataset():
    """Call ``delete_model_dataset`` for (``model``, X, y) before and after a search.

    The first call may run when nothing is stored for this model/dataset
    pair; the second follows a search on ``search_config``.
    """
    delete_model_dataset(model, X, y)

    optimizer = Hyperactive(X, y)
    optimizer.search(search_config)

    delete_model_dataset(model, X, y)
70
71
72 View Code Duplication
def test_connect_model_IDs():
    """After ``connect_model_IDs(model1, model2)``, model2's search is much faster.

    Both models are deleted from memory, then connected. A 1000-iteration
    long-memory search on ``search_config1`` runs first; the same search on
    ``search_config2`` must then take less than half as long (presumably
    because it can reuse model1's stored evaluations).
    """
    delete_model(model1)
    delete_model(model2)

    connect_model_IDs(model1, model2)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump on clock adjustments, corrupting durations.
    c_time = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config1, n_iter=1000)
    diff_time_0 = time.perf_counter() - c_time

    c_time = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config2, n_iter=1000)
    diff_time_1 = time.perf_counter() - c_time

    assert diff_time_0 / 2 > diff_time_1, (
        f"connected search took {diff_time_1:.3f}s; expected less than half "
        f"of the first search's {diff_time_0:.3f}s"
    )
89
90
91 View Code Duplication
def test_split_model_IDs():
    """After connecting then splitting model1/model2, model2 gets no speed-up.

    Both models are deleted from memory, connected, then split again. A
    1000-iteration long-memory search on ``search_config1`` runs first; the
    same search on ``search_config2`` must then take at least half as long
    (presumably because it can no longer reuse model1's stored evaluations).
    """
    delete_model(model1)
    delete_model(model2)

    connect_model_IDs(model1, model2)

    split_model_IDs(model1, model2)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump on clock adjustments, corrupting durations.
    c_time = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config1, n_iter=1000)
    diff_time_0 = time.perf_counter() - c_time

    c_time = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config2, n_iter=1000)
    diff_time_1 = time.perf_counter() - c_time

    assert diff_time_0 / 2 < diff_time_1, (
        f"split search took {diff_time_1:.3f}s; expected more than half "
        f"of the first search's {diff_time_0:.3f}s"
    )
110
111
112
def test_get_best_model():
    """Query the best stored model for (X, y), then reset all memory.

    The reset runs in a ``finally`` block, so long-term memory written by the
    earlier ``memory="long"`` tests is cleared even if ``get_best_model`` raises.
    """
    try:
        # Unpacking checks that three values are returned; they are not
        # inspected further. Underscore names avoid shadowing the module-level
        # ``search_config``.
        _score, _best_search_config, _best_init_config = get_best_model(X, y)
    finally:
        reset_memory(force_true=True)
116