Completed
Push — master ( b4252e...a98e6f )
by Simon
01:25
created

tests.test_memory_helpers.test_split_model_IDs()   A

Complexity

Conditions 1

Size

Total Lines 19
Code Lines 14

Duplication

Lines 19
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 14
nop 0
dl 19
loc 19
rs 9.7
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
7
from sklearn.datasets import load_iris
8
from sklearn.model_selection import cross_val_score
9
from sklearn.tree import DecisionTreeClassifier
10
from hyperactive import Hyperactive
11
from hyperactive.memory import (
12
    delete_model,
13
    delete_model_dataset,
14
    connect_model_IDs,
15
    split_model_IDs,
16
)
17
18
# Load the dataset once at import time; X and y are reused by every test
# in this module.
data = load_iris()
X = data.data
y = data.target
20
21
22
def model(para, X_train, y_train):
    """Return the mean 2-fold cross-validation score of a decision tree.

    The tree's ``criterion`` is read from ``para["criterion"]``; the score is
    computed on ``X_train`` / ``y_train``.
    """
    clf = DecisionTreeClassifier(criterion=para["criterion"])
    return cross_val_score(clf, X_train, y_train, cv=2).mean()
27
28
29
def model1(para, X_train, y_train):
    """Return the mean 2-fold cross-validation score of a decision tree.

    The tree's ``max_depth`` is read from ``para["max_depth"]``; the score is
    computed on ``X_train`` / ``y_train``.
    """
    clf = DecisionTreeClassifier(max_depth=para["max_depth"])
    return cross_val_score(clf, X_train, y_train, cv=2).mean()
34
35
36
def model2(para, X_train, y_train):
    """Return the mean 2-fold cross-validation score of a decision tree.

    The tree's ``max_depth`` is read from ``para["max_depth"]``; the score is
    computed on ``X_train`` / ``y_train``.
    """
    clf = DecisionTreeClassifier(max_depth=para["max_depth"])
    return cross_val_score(clf, X_train, y_train, cv=2).mean()
41
42
43
# One configuration per objective function; each is handed to
# Hyperactive.search() by the tests below. Keys are the objective functions,
# values map a parameter name to the candidate values for that parameter.
search_config = {
    model: {"criterion": ["gini"]},
}
search_config1 = {
    model1: {"max_depth": range(2, 500)},
}
search_config2 = {
    model2: {"max_depth": range(2, 500)},
}
46
47
48
def test_delete_model():
    """Call ``delete_model`` on ``model`` both before and after a search."""
    # First call happens before any search has run in this test.
    delete_model(model)

    Hyperactive(X, y).search(search_config)

    # Second call happens after the search above has completed.
    delete_model(model)
55
56
57
def test_delete_model_dataset():
    """Call ``delete_model_dataset`` for ``model`` and the module dataset
    both before and after a search."""
    # First call happens before any search has run in this test.
    delete_model_dataset(model, X, y)

    Hyperactive(X, y).search(search_config)

    # Second call happens after the search above has completed.
    delete_model_dataset(model, X, y)
64
65
66 View Code Duplication
def test_connect_model_IDs():
    """Check that searching ``model2`` after connecting it to ``model1`` is
    more than twice as fast as the preceding search on ``model1``.

    Both searches use ``memory="long"`` and 1000 iterations.
    NOTE(review): presumably the speed-up comes from the connected models
    sharing stored results -- confirm against ``hyperactive.memory``.
    NOTE(review): the code-quality report flags this test as duplicated
    elsewhere in the project (it mirrors ``test_split_model_IDs``).
    """
    # Start from a clean state for both models.
    delete_model(model1)
    delete_model(model2)

    connect_model_IDs(model1, model2)

    # perf_counter() is monotonic and high-resolution, so the measured
    # durations cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config1, n_iter=1000)
    diff_time_0 = time.perf_counter() - start

    start = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config2, n_iter=1000)
    diff_time_1 = time.perf_counter() - start

    assert diff_time_0 / 2 > diff_time_1, (
        "expected second search to take less than half the first: "
        "first={:.3f}s second={:.3f}s".format(diff_time_0, diff_time_1)
    )
83
84
85 View Code Duplication
def test_split_model_IDs():
    """Check that after connecting and then splitting ``model1`` and
    ``model2``, the search on ``model2`` is NOT more than twice as fast as
    the preceding search on ``model1``.

    Both searches use ``memory="long"`` and 1000 iterations.
    NOTE(review): presumably splitting stops the second model from reusing
    the first model's stored results -- confirm against
    ``hyperactive.memory``.
    NOTE(review): the code-quality report flags this test as duplicated
    elsewhere in the project (it mirrors ``test_connect_model_IDs``).
    """
    # Start from a clean state for both models.
    delete_model(model1)
    delete_model(model2)

    # Connect first so that the split below has a link to undo.
    connect_model_IDs(model1, model2)

    split_model_IDs(model1, model2)

    # perf_counter() is monotonic and high-resolution, so the measured
    # durations cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config1, n_iter=1000)
    diff_time_0 = time.perf_counter() - start

    start = time.perf_counter()
    opt = Hyperactive(X, y, memory="long")
    opt.search(search_config2, n_iter=1000)
    diff_time_1 = time.perf_counter() - start

    assert diff_time_0 / 2 < diff_time_1, (
        "expected second search to take more than half the first: "
        "first={:.3f}s second={:.3f}s".format(diff_time_0, diff_time_1)
    )
104