Completed
Push — main ( 183d7f...45ab67 )
by Yunguan
19s queued 13s
created

test_config_sanity_check()   B

Complexity

Conditions 7

Size

Total Lines 94
Code Lines 57

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 57
dl 0
loc 94
rs 7.0072
c 0
b 0
f 0
cc 7
nop 1

How to fix: Long Method

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

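Commonly applied refactorings include Extract Method: move a cohesive (often already commented) part of the body into a new method, and name that method after the comment.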

1
"""
2
Tests functions in config/parser.py
3
"""
4
5
import os
6
7
import pytest
8
import yaml
9
from testfixtures import TempDirectory
10
11
from deepreg.config.parser import (
12
    config_sanity_check,
13
    load_configs,
14
    save,
15
    update_nested_dict,
16
)
17
18
19
def test_update_nested_dict():
    """
    Test update_nested_dict by comparing merged outputs with expected values.

    Each merge case builds its own fresh input dicts. The single invalid case
    (a dict trying to overwrite a non-dict value) is checked separately
    for a TypeError.
    """
    # (base dict, update dict, expected merged result)
    merge_cases = [
        # two flat dicts with different keys
        (dict(d=1), dict(v=0), dict(d=1, v=0)),
        # two flat dicts sharing a key: the update value wins
        (dict(d=1), dict(d=0), dict(d=0)),
        # update adds a nested dict under a new key
        (dict(d=1), dict(v=dict(x=0)), dict(d=1, v=dict(x=0))),
        # nested common key: a non-dict value may overwrite a dict
        (dict(v=dict(x=0)), dict(v=1), dict(v=1)),
        # nested common key: overwrite one inner value, keep the others
        (dict(v=dict(x=0, y=1)), dict(v=dict(x=1)), dict(v=dict(x=1, y=1))),
        # nested common key: add a new inner value
        (dict(v=dict(x=0, y=1)), dict(v=dict(z=1)), dict(v=dict(x=0, y=1, z=1))),
    ]
    for base, update, expected in merge_cases:
        assert update_nested_dict(base, update) == expected

    # nested common key: a dict cannot overwrite a non-dict value
    with pytest.raises(TypeError) as err_info:
        update_nested_dict(dict(v=1), dict(v=dict(x=0)))
    assert "'int' object does not support item assignment" in str(err_info.value)
73
74
75
class TestLoadConfigs:
    """Tests for load_configs with single, multiple and outdated config files."""

    @staticmethod
    def _load_yaml(file_path):
        """
        Parse a yaml file to build the expected value of a test.

        :param file_path: path of the yaml file to read.
        :return: the parsed yaml content.
        """
        with open(file_path) as file:
            return yaml.load(file, Loader=yaml.FullLoader)

    def test_single_config(self):
        """Loading one config path returns the same content as parsing it directly."""
        expected = self._load_yaml("config/unpaired_labeled_ddf.yaml")
        got = load_configs("config/unpaired_labeled_ddf.yaml")
        assert got == expected

    def test_multiple_configs(self):
        """
        Loading the three partial test configs together returns the same
        content as config/unpaired_labeled_ddf.yaml.
        """
        expected = self._load_yaml("config/unpaired_labeled_ddf.yaml")
        got = load_configs(
            config_path=[
                "config/test/ddf.yaml",
                "config/test/unpaired_nifti.yaml",
                "config/test/labeled.yaml",
            ]
        )
        assert got == expected

    def test_outdated_config(self):
        """
        Loading the outdated v011 config returns the current demo config content.

        load_configs is expected to write an updated copy of the config to
        ``updated_file_path``. That file is removed in ``finally`` so a failing
        assertion does not leave a stale generated file behind for later runs.
        """
        expected = self._load_yaml("demos/grouped_mr_heart/grouped_mr_heart.yaml")
        updated_file_path = "config/test/updated_grouped_mr_heart_v011.yaml"
        try:
            got = load_configs("config/test/grouped_mr_heart_v011.yaml")
            assert got == expected
            assert os.path.isfile(updated_file_path)
        finally:
            # clean up regardless of the assertion outcome
            if os.path.isfile(updated_file_path):
                os.remove(updated_file_path)
102
103
104
def test_save():
    """
    Test save by checking the existence of the written file and the error
    raised for a non-yaml file name.
    """
    # (extra keyword arguments for save, file name expected on disk)
    written_cases = (
        ({}, "config.yaml"),  # default file name
        (dict(filename="test.yaml"), "test.yaml"),  # custom file name
    )
    for extra_kwargs, expected_name in written_cases:
        with TempDirectory() as tempdir:
            save(config=dict(x=1), out_dir=tempdir.path, **extra_kwargs)
            expected_path = os.path.join(tempdir.path, expected_name)
            assert os.path.exists(expected_path)

    # a non-yaml file name is rejected
    with TempDirectory() as tempdir, pytest.raises(AssertionError):
        save(config=dict(x=1), out_dir=tempdir.path, filename="test.txt")
120
121
122
def _assert_sanity_check_raises(config, error_type, message=None):
    """
    Assert that config_sanity_check raises for the given config.

    :param config: configuration dict passed to config_sanity_check.
    :param error_type: exception class that must be raised.
    :param message: optional substring the error message must contain.
    """
    with pytest.raises(error_type) as err_info:
        config_sanity_check(config=config)
    if message is not None:
        assert message in str(err_info.value)


def _paired_h5_dataset(dir_config, labeled):
    """
    Build the ``dataset`` section of a paired h5 config.

    :param dir_config: value stored under the ``dir`` key.
    :param labeled: value stored under the ``labeled`` key.
    :return: the dataset config dict.
    """
    return dict(type="paired", format="h5", dir=dir_config, labeled=labeled)


def test_config_sanity_check(caplog):
    """
    Test config_sanity_check by checking error and warning messages.

    :param caplog: used to check warning messages.
    """
    # unknown data type
    _assert_sanity_check_raises(
        dict(dataset=dict(type="type")),
        ValueError,
        "data type must be paired / unpaired / grouped",
    )

    # unknown data format
    _assert_sanity_check_raises(
        dict(dataset=dict(type="paired", format="format")),
        ValueError,
        "data format must be nifti / h5",
    )

    # dir is not in data_config
    _assert_sanity_check_raises(
        dict(dataset=dict(type="paired", format="h5")), AssertionError
    )

    # dir doesn't have train/valid/test
    _assert_sanity_check_raises(
        dict(dataset=dict(type="paired", format="h5", dir=dict())), AssertionError
    )

    # train/valid/test of dir is not a string or a list of strings
    _assert_sanity_check_raises(
        dict(
            dataset=_paired_h5_dataset(
                dict(train=1, valid=None, test=None), labeled=True
            ),
            train=dict(model=dict(method="ddf")),
        ),
        ValueError,
        "data_dir for mode train must be string or list of strings",
    )

    # unlabeled data used for a conditional model
    _assert_sanity_check_raises(
        dict(
            dataset=_paired_h5_dataset(
                dict(train=None, valid=None, test=None), labeled=False
            ),
            train=dict(
                method="conditional",
                loss=dict(),
                preprocess=dict(),
                optimizer=dict(name="Adam"),
            ),
        ),
        ValueError,
        "For conditional model, data have to be labeled, got unlabeled data.",
    )

    # warnings: train/valid/test of dir is None and all loss weights are <= 0
    # NOTE(review): only the data-directory warnings are asserted below; the
    # loss-weight scenario is exercised but its warning is not checked.
    caplog.clear()  # drop logs produced by the earlier cases
    config_sanity_check(
        config=dict(
            dataset=_paired_h5_dataset(
                dict(train=None, valid=None, test=None), labeled=False
            ),
            train=dict(
                method="ddf",
                loss=dict(
                    image=dict(name="lncc", weight=0.0),
                    label=dict(name="ssd", weight=0.0),
                    regularization=dict(name="bending", weight=0.0),
                ),
                preprocess=dict(),
                optimizer=dict(name="Adam"),
            ),
        )
    )
    # all warning messages are captured together in caplog.text
    for mode in ("train", "valid", "test"):
        assert f"Data directory for {mode} is not defined." in caplog.text
216