Passed
Push — master ( d7387e...d64774 )
by Simon
03:49
created

tests.test_issue_25   A

Complexity

Total Complexity 1

Size/Duplication

Total Lines 53
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 28
dl 0
loc 53
rs 10
c 0
b 0
f 0
wmc 1

1 Function

Rating   Name   Duplication   Size   Complexity  
A test_issue_25() 0 46 1
1
import numpy as np
2
import pandas as pd
3
4
from hyperactive import Hyperactive
5
6
7
def test_issue_25():
    """Run a search whose objective always scores NaN, then warm-start from it.

    The objective function records every evaluated parameter set together
    with its (NaN) score in a CSV file. After the first run the CSV is read
    back and handed to a second search over the same search space via
    ``memory_warm_start``. The test contains no assertions; it passes when
    both runs complete without raising.
    """
    # path of the CSV file that collects the search data
    path = "./search_data.csv"
    search_space = {
        "n_neighbors": list(range(1, 50)),
    }

    # column names: the search-space parameters plus the score
    para_names = list(search_space.keys()) + ["score"]

    # start from a header-only CSV so every evaluation can append one row
    pd.DataFrame(columns=para_names).to_csv(path, index=False)

    def objective_function(para):
        """Log the evaluated parameters with a NaN score and return NaN."""
        # every evaluation deliberately yields NaN
        score = np.nan

        # you can access the entire dictionary from "para"
        # NOTE(review): the original comment calls this a copy; confirm that
        # para.para_dict does not alias state owned by Hyperactive.
        parameter_dict = para.para_dict

        # save the score next to the parameters
        parameter_dict["score"] = score

        # Append the new row directly to the CSV file. DataFrame.append was
        # removed in pandas 2.0, and appending in place also avoids reading
        # and rewriting the whole file on every evaluation.
        new_row = pd.DataFrame(parameter_dict, columns=para_names, index=[0])
        new_row.to_csv(path, mode="a", header=False, index=False, na_rep="nan")

        return score

    hyper0 = Hyperactive()
    hyper0.add_search(objective_function, search_space, n_iter=50)
    hyper0.run()

    search_data_0 = pd.read_csv(path, na_values="nan")

    # the second run should be much faster than before,
    # because Hyperactive already knows most parameters/scores
    hyper1 = Hyperactive()
    hyper1.add_search(
        objective_function, search_space, n_iter=50, memory_warm_start=search_data_0
    )
    hyper1.run()
53