Passed
Pull Request — main (#756)
by
unknown
02:07 queued 30s
created

test.unit.test_config_parser.test_save()   A

Complexity

Conditions 5

Size

Total Lines 16
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 16
rs 9.3333
c 0
b 0
f 0
cc 5
nop 0
1
"""
2
Tests for functions in config/parser.py
3
"""
4
5
import os
6
7
import pytest
8
import yaml
9
from testfixtures import TempDirectory
10
11
from deepreg.config.parser import (
12
    config_sanity_check,
13
    has_wandb_callback,
14
    load_configs,
15
    save,
16
    update_nested_dict,
17
)
18
19
20
def test_update_nested_dict():
    """Check update_nested_dict merges an update dict into a base dict.

    Values from the update take precedence, nested dicts sharing a key are
    merged key by key, and merging a dict into a non-dict value raises
    TypeError.
    """
    # each entry: (base dict, update dict, expected merge result)
    merge_cases = [
        # two flat dicts with disjoint keys
        ({"d": 1}, {"v": 0}, {"d": 1, "v": 0}),
        # two flat dicts sharing a key: the update wins
        ({"d": 1}, {"d": 0}, {"d": 0}),
        # nested dict under a key the base does not have
        ({"d": 1}, {"v": {"x": 0}}, {"d": 1, "v": {"x": 0}}),
        # a non-dict value is allowed to replace a nested dict
        ({"v": {"x": 0}}, {"v": 1}, {"v": 1}),
        # nested dicts sharing a key: overwrite one inner value
        ({"v": {"x": 0, "y": 1}}, {"v": {"x": 1}}, {"v": {"x": 1, "y": 1}}),
        # nested dicts sharing a key: add a new inner value
        ({"v": {"x": 0, "y": 1}}, {"v": {"z": 1}}, {"v": {"x": 0, "y": 1, "z": 1}}),
    ]
    for base, update, want in merge_cases:
        assert update_nested_dict(base, update) == want

    # a dict cannot overwrite a non-dict value stored under the same key
    with pytest.raises(TypeError) as exc:
        update_nested_dict({"v": 1}, {"v": {"x": 0}})
    assert "'int' object does not support item assignment" in str(exc.value)
74
75
76
class TestLoadConfigs:
    """Tests for load_configs."""

    def test_single_config(self):
        """Loading one config file returns the same content as parsing it directly."""
        with open("config/unpaired_labeled_ddf.yaml") as file:
            expected = yaml.load(file, Loader=yaml.FullLoader)
        got = load_configs("config/unpaired_labeled_ddf.yaml")
        assert got == expected

    def test_multiple_configs(self):
        """Loading several partial config files yields the combined full config."""
        with open("config/unpaired_labeled_ddf.yaml") as file:
            expected = yaml.load(file, Loader=yaml.FullLoader)
        got = load_configs(
            config_path=[
                "config/test/ddf.yaml",
                "config/test/unpaired_nifti.yaml",
                "config/test/labeled.yaml",
            ]
        )
        assert got == expected

    def test_outdated_config(self):
        """An outdated config is upgraded and an updated copy is written to disk.

        The generated copy is deleted in a ``finally`` block so that a failing
        assertion does not leave it behind to affect later runs.
        """
        with open("demos/grouped_mr_heart/grouped_mr_heart.yaml") as file:
            expected = yaml.load(file, Loader=yaml.FullLoader)
        updated_file_path = "config/test/updated_grouped_mr_heart_v011.yaml"
        # Remove any stale copy left by an earlier run; otherwise the isfile
        # check below would pass even if load_configs wrote nothing.
        if os.path.isfile(updated_file_path):
            os.remove(updated_file_path)
        try:
            got = load_configs("config/test/grouped_mr_heart_v011.yaml")
            assert got == expected
            assert os.path.isfile(updated_file_path)
        finally:
            if os.path.isfile(updated_file_path):
                os.remove(updated_file_path)
103
104
105
def test_save():
    """Test save writes the config file and rejects non-YAML file names.

    Covers the default file name, a custom ``.yaml`` file name, and a
    ``.txt`` file name, which must raise AssertionError.
    """
    # default file name
    with TempDirectory() as tempdir:
        save(config=dict(x=1), out_dir=tempdir.path)
        assert os.path.exists(os.path.join(tempdir.path, "config.yaml"))

    # custom file name
    with TempDirectory() as tempdir:
        save(config=dict(x=1), out_dir=tempdir.path, filename="test.yaml")
        assert os.path.exists(os.path.join(tempdir.path, "test.yaml"))

    # non yaml filename
    with TempDirectory() as tempdir:
        with pytest.raises(AssertionError):
            save(config=dict(x=1), out_dir=tempdir.path, filename="test.txt")
121
122
123
class TestConfigSanityCheck:
    """Tests for config_sanity_check."""

    def test_cond_err(self):
        """Test error message for conditional model."""
        with pytest.raises(ValueError) as err_info:
            config_sanity_check(
                config=dict(
                    dataset=dict(
                        type="paired",
                        format="h5",
                        dir=dict(train=None, valid=None, test=None),
                        labeled=False,
                    ),
                    train=dict(
                        method="conditional",
                        loss=dict(),
                        preprocess=dict(),
                        optimizer=dict(name="Adam"),
                    ),
                )
            )
        assert (
            "For conditional model, data have to be labeled, got unlabeled data."
            in str(err_info.value)
        )

    def test_warnings(self, caplog):
        """Test warnings are logged when train/valid/test data dirs are None.

        The config also sets every loss weight to 0.0; only the data-directory
        warnings are asserted here.
        """
        caplog.clear()  # drop any records captured before this call
        config_sanity_check(
            config=dict(
                dataset=dict(
                    type="paired",
                    format="h5",
                    dir=dict(train=None, valid=None, test=None),
                    labeled=False,
                ),
                train=dict(
                    method="ddf",
                    loss=dict(
                        image=dict(name="lncc", weight=0.0),
                        label=dict(name="ssd", weight=0.0),
                        regularization=dict(name="bending", weight=0.0),
                    ),
                    preprocess=dict(),
                    optimizer=dict(name="Adam"),
                ),
            )
        )
        # all warning messages end up in the same captured log text
        assert "Data directory for train is not defined." in caplog.text
        assert "Data directory for valid is not defined." in caplog.text
        assert "Data directory for test is not defined." in caplog.text


@pytest.mark.parametrize(
    "test_dict, expect", [[{"wandb": True}, True], [{"random": False}, False]]
)
def test_has_wandb_callback(test_dict, expect):
    """Test has_wandb_callback returns the expected flag for each input dict."""
    got = has_wandb_callback(test_dict)
    assert got == expect
216