Passed
Push — master ( 749acb...38381f )
by Simon
01:38
created

search_best_tree_ensemble.warn()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
import warnings


def warn(*args, **kwargs):
    """Discard a warning instead of emitting it.

    Accepts any positional and keyword arguments so it can stand in for
    ``warnings.warn``; it performs no action and returns ``None``.
    """
    return None


# Swap the module attribute ``warnings.warn`` for the no-op above to keep
# scikit-learn quiet. Note the scope: every caller that looks up
# ``warnings.warn`` after this assignment is silenced, not only sklearn.
warnings.warn = warn
6
7
from sklearn.datasets import load_breast_cancer
8
from sklearn.model_selection import cross_val_score
9
10
from sklearn.ensemble import GradientBoostingClassifier
11
from sklearn.ensemble import RandomForestClassifier
12
from sklearn.ensemble import ExtraTreesClassifier
13
from xgboost import XGBClassifier
14
15
from hyperactive import Hyperactive
16
17
# Load scikit-learn's bundled breast-cancer dataset and split it into the
# inputs (X) and labels (y) that every objective function cross-validates on.
data = load_breast_cancer()
X = data.data
y = data.target
19
20
21
# Number of cross-validation folds used to score every candidate model.
_CV_FOLDS = 3

# Hyperparameters shared by the scikit-learn tree ensembles below; each one is
# read from the ``para`` mapping handed to the objective functions.
_TREE_PARAM_NAMES = (
    "n_estimators",
    "max_depth",
    "min_samples_split",
    "min_samples_leaf",
)


def _tree_params(para):
    """Return the shared tree-ensemble hyperparameters from ``para`` as a dict.

    Each value is fetched with ``para[name]``, so a missing key fails exactly
    as it would with direct indexing.
    """
    return {name: para[name] for name in _TREE_PARAM_NAMES}


def _mean_cv_score(model, X, y):
    """Cross-validate ``model`` on ``(X, y)`` and return the mean fold score."""
    scores = cross_val_score(model, X, y, cv=_CV_FOLDS)
    return scores.mean()


# NOTE(review): none of the models below set ``random_state``, so repeated
# evaluations of the same ``para`` may yield slightly different scores.


def GradientBoostingClassifier_(para, X, y):
    """Objective: mean CV score of a GradientBoostingClassifier built from ``para``."""
    model = GradientBoostingClassifier(**_tree_params(para))
    return _mean_cv_score(model, X, y)


def RandomForestClassifier_(para, X, y):
    """Objective: mean CV score of a RandomForestClassifier built from ``para``."""
    model = RandomForestClassifier(**_tree_params(para))
    return _mean_cv_score(model, X, y)


def ExtraTreesClassifier_(para, X, y):
    """Objective: mean CV score of an ExtraTreesClassifier built from ``para``."""
    model = ExtraTreesClassifier(**_tree_params(para))
    return _mean_cv_score(model, X, y)


def XGBoost_(para, X, y):
    """Objective: mean CV score of an XGBClassifier built from ``para``.

    Only ``n_estimators`` and ``max_depth`` are read, matching the search
    space configured for this objective.
    """
    model = XGBClassifier(
        n_estimators=para["n_estimators"],
        max_depth=para["max_depth"],
    )
    return _mean_cv_score(model, X, y)
52
53
54
# Search space handed to Hyperactive: each objective function maps to the
# candidate values for every hyperparameter it reads from ``para``.
search_config = {
    GradientBoostingClassifier_: dict(
        n_estimators=range(50, 300, 5),
        max_depth=range(2, 10),
        min_samples_split=range(2, 20),
        min_samples_leaf=range(2, 20),
    ),
    RandomForestClassifier_: dict(
        n_estimators=range(5, 100),
        max_depth=range(2, 20),
        min_samples_split=range(2, 20),
        min_samples_leaf=range(2, 20),
    ),
    ExtraTreesClassifier_: dict(
        n_estimators=range(50, 300, 5),
        max_depth=range(2, 20),
        min_samples_split=range(2, 20),
        min_samples_leaf=range(2, 20),
    ),
    XGBoost_: dict(
        n_estimators=range(50, 300, 5),
        max_depth=range(2, 20),
    ),
}
79
80
81
if __name__ == "__main__":
    # Run the search only when this file is executed as a script, not on import.
    # NOTE(review): with n_jobs=4 the search presumably starts worker processes;
    # under the "spawn" start method each worker re-imports this module, so an
    # unguarded top-level search would be re-triggered in every worker.
    opt = Hyperactive(search_config, n_jobs=4, n_iter=100)
    opt.search(X, y)
83