Passed
Pull Request — master (#101)
by Simon
01:37
created

tests.test_api.test_catch   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 138
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 96
dl 0
loc 138
rs 10
c 0
b 0
f 0
wmc 5

5 Functions

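| Rating | Name | Duplication | Size | Complexity |
| ------ | ------------------ | ----------- | ---- | ---------- |
| A | test_catch_3() | 0 | 17 | 1 |
| A | test_catch_all_1() | 0 | 33 | 1 |
| A | test_catch_all_0() | 0 | 29 | 1 |
| A | test_catch_1() | 0 | 17 | 1 |
| A | test_catch_2() | 0 | 17 | 1 |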
1
import copy
2
import pytest
3
import math
4
import numpy as np
5
import pandas as pd
6
7
from hyperactive.optimizers import HillClimbingOptimizer
8
from hyperactive.experiment import BaseExperiment, add_catch
9
from hyperactive.search_config import SearchConfig
10
11
12
# Shared one-dimensional search space: ``x0`` takes the integers -10 .. 9.
# The objectives in this module ignore their ``access`` argument, so the
# exact values only matter as positions for the optimizer to visit.
search_config = SearchConfig(
    x0=list(np.arange(-10, 10)),
)
15
16
17
def test_catch_1():
    """A ``TypeError`` raised by the objective is caught and scored as NaN.

    ``hyper.run()`` must complete without the exception escaping, and every
    recorded score must be the catch value ``np.nan``.
    """

    class Experiment(BaseExperiment):
        @add_catch({TypeError: np.nan})
        def objective_function(self, access):
            1 + "str"  # always raises TypeError

            return 0

    experiment = Experiment()

    hyper = HillClimbingOptimizer()
    hyper.add_search(
        experiment,
        search_config,
        n_iter=100,
    )
    hyper.run()

    # Every evaluation raised, so only the catch value may appear as a score.
    scores = hyper.search_data(experiment)["score"]
    assert not scores.empty
    assert scores.isna().all()
34
35
36
def test_catch_2():
    """A ``ValueError`` raised by the objective is caught and scored as NaN.

    ``math.sqrt`` of a negative number raises ``ValueError``. The run must
    complete, and every recorded score must be the catch value ``np.nan``.
    """

    class Experiment(BaseExperiment):
        @add_catch({ValueError: np.nan})
        def objective_function(self, access):
            math.sqrt(-10)  # always raises ValueError (math domain error)

            return 0

    experiment = Experiment()

    hyper = HillClimbingOptimizer()
    hyper.add_search(
        experiment,
        search_config,
        n_iter=100,
    )
    hyper.run()

    # Every evaluation raised, so only the catch value may appear as a score.
    scores = hyper.search_data(experiment)["score"]
    assert not scores.empty
    assert scores.isna().all()
53
54
55
def test_catch_3():
    """A ``ZeroDivisionError`` from the objective is caught and scored as NaN.

    The run must complete without the exception escaping, and every recorded
    score must be the catch value ``np.nan``.
    """

    class Experiment(BaseExperiment):
        @add_catch({ZeroDivisionError: np.nan})
        def objective_function(self, access):
            1 / 0  # always raises ZeroDivisionError

            return 0

    experiment = Experiment()

    hyper = HillClimbingOptimizer()
    hyper.add_search(
        experiment,
        search_config,
        n_iter=100,
    )
    hyper.run()

    # Every evaluation raised, so only the catch value may appear as a score.
    scores = hyper.search_data(experiment)["score"]
    assert not scores.empty
    assert scores.isna().all()
72
73
74
def test_catch_all_0():
    """Several exception types can be mapped to a catch value at once.

    Note: the first statement of the objective always raises ``TypeError``,
    so the ``ValueError`` and ``ZeroDivisionError`` lines below it never
    execute. This test therefore exercises only the ``TypeError`` entry of a
    multi-entry catch mapping.
    """

    class Experiment(BaseExperiment):
        @add_catch(
            {
                TypeError: np.nan,
                ValueError: np.nan,
                ZeroDivisionError: np.nan,
            }
        )
        def objective_function(self, access):
            1 + "str"  # raises TypeError; the lines below are unreachable
            math.sqrt(-10)
            1 / 0

            return 0

    experiment = Experiment()

    hyper = HillClimbingOptimizer()
    hyper.add_search(
        experiment,
        search_config,
        n_iter=100,
    )
    hyper.run()

    # Every evaluation raised, so every recorded score must be the NaN catch value.
    scores = hyper.search_data(experiment)["score"]
    assert not scores.empty
    assert scores.isna().all()
103
104
105
def test_catch_all_1():
    """A catch value may be a ``(score, extra_data)`` tuple.

    When the objective raises, ``catch_return`` is used in place of the
    normal return value. Its ``extra_data`` dict must then appear as an
    ``error`` column in the search data.

    Note: the first statement of the objective always raises ``TypeError``,
    so the ``ValueError`` and ``ZeroDivisionError`` lines never execute.
    """
    catch_return = (np.nan, {"error": True})

    class Experiment(BaseExperiment):
        @add_catch(
            {
                TypeError: catch_return,
                ValueError: catch_return,
                ZeroDivisionError: catch_return,
            }
        )
        def objective_function(self, access):
            1 + "str"  # raises TypeError; the lines below are unreachable
            math.sqrt(-10)
            1 / 0

            return 0, {"error": False}

    experiment = Experiment()

    hyper = HillClimbingOptimizer()
    hyper.add_search(
        experiment,
        search_config,
        n_iter=100,
    )
    hyper.run()

    # Fetch once; every row comes from a caught exception.
    search_data = hyper.search_data(experiment)
    assert not search_data.empty
    assert search_data["score"].isna().all()
    # Truth check instead of ``== True`` (E712). Values may be numpy bools.
    assert search_data["error"].all()
138