Passed
Branch main (46851d)
by Yunguan
02:04
created

TestConfigSanityCheck.test_cond_err()   A

Complexity

Conditions 2

Size

Total Lines 22
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 22
rs 9.55
c 0
b 0
f 0
cc 2
nop 1
1
"""
2
Tests functions in config/parser.py
3
"""
4
5
import os
6
7
import pytest
8
import yaml
9
from testfixtures import TempDirectory
10
11
from deepreg.config.parser import (
12
    config_sanity_check,
13
    load_configs,
14
    save,
15
    update_nested_dict,
16
)
17
18
19
def test_update_nested_dict():
    """Check update_nested_dict against a table of merge scenarios.

    Every successful scenario is a (base, update, expected) triple whose merged
    result must equal ``expected``. The single scenario in which a dict tries to
    overwrite a non-dict value must instead raise a TypeError.
    """
    merge_cases = [
        # two flat dicts with disjoint keys: both keys survive
        (dict(d=1), dict(v=0), dict(d=1, v=0)),
        # two flat dicts sharing a key: the update's value wins
        (dict(d=1), dict(d=0), dict(d=0)),
        # nested update with no shared key: the nested dict is added as-is
        (dict(d=1), dict(v=dict(x=0)), dict(d=1, v=dict(x=0))),
        # shared key: a non-dict update may replace a nested dict
        (dict(v=dict(x=0)), dict(v=1), dict(v=1)),
        # shared nested key: an existing leaf value is overwritten
        (dict(v=dict(x=0, y=1)), dict(v=dict(x=1)), dict(v=dict(x=1, y=1))),
        # shared nested key: a new leaf value is added next to the old ones
        (dict(v=dict(x=0, y=1)), dict(v=dict(z=1)), dict(v=dict(x=0, y=1, z=1))),
    ]
    for base, update, expected in merge_cases:
        assert update_nested_dict(base, update) == expected

    # a dict update cannot be merged into an existing non-dict (int) value
    with pytest.raises(TypeError) as err_info:
        update_nested_dict(dict(v=1), dict(v=dict(x=0)))
    assert "'int' object does not support item assignment" in str(err_info.value)
73
74
75
class TestLoadConfigs:
    """Tests for load_configs with single, merged and outdated config files.

    NOTE(review): every path here is relative, so these tests assume pytest is
    run from the repository root -- confirm against the CI configuration.
    """

    @staticmethod
    def _read_yaml(path):
        """Parse a YAML file directly (bypassing load_configs) to get the reference value."""
        with open(path) as file:
            return yaml.load(file, Loader=yaml.FullLoader)

    def test_single_config(self):
        """Loading one config path yields the same dict as parsing that file directly."""
        expected = self._read_yaml("config/unpaired_labeled_ddf.yaml")
        got = load_configs("config/unpaired_labeled_ddf.yaml")
        assert got == expected

    def test_multiple_configs(self):
        """Loading the three partial test configs together equals the full ddf config."""
        expected = self._read_yaml("config/unpaired_labeled_ddf.yaml")
        got = load_configs(
            config_path=[
                "config/test/ddf.yaml",
                "config/test/unpaired_nifti.yaml",
                "config/test/labeled.yaml",
            ]
        )
        assert got == expected

    def test_outdated_config(self):
        """Loading the outdated v011 test config matches the current demo config.

        The test also asserts that load_configs writes an updated copy of the
        config to disk. That copy is removed in a ``finally`` clause, so a
        failing assertion does not leave the generated file in the repository.
        """
        expected = self._read_yaml("demos/grouped_mr_heart/grouped_mr_heart.yaml")
        updated_file_path = "config/test/updated_grouped_mr_heart_v011.yaml"
        try:
            got = load_configs("config/test/grouped_mr_heart_v011.yaml")
            assert got == expected
            assert os.path.isfile(updated_file_path)
        finally:
            # Clean up whatever the outcome. The file may be absent if
            # load_configs raised before writing it, hence the existence check.
            if os.path.isfile(updated_file_path):
                os.remove(updated_file_path)
102
103
104
def test_save():
    """Test save: the file appears under the expected name; non-YAML names are rejected."""
    # first the default filename, then an explicit .yaml filename
    for extra_kwargs, expected_name in (
        ({}, "config.yaml"),
        (dict(filename="test.yaml"), "test.yaml"),
    ):
        with TempDirectory() as tempdir:
            save(config=dict(x=1), out_dir=tempdir.path, **extra_kwargs)
            saved_path = os.path.join(tempdir.path, expected_name)
            assert os.path.exists(saved_path)

    # a filename without a .yaml extension must make save raise AssertionError
    with TempDirectory() as tempdir, pytest.raises(AssertionError):
        save(config=dict(x=1), out_dir=tempdir.path, filename="test.txt")
120
121
122
class TestConfigSanityCheck:
    """Tests for the errors raised by config_sanity_check."""

    def test_cond_err(self):
        """A conditional model combined with unlabeled data is rejected with a ValueError."""
        dataset_config = {
            "type": "paired",
            "format": "h5",
            "dir": {"train": None, "valid": None, "test": None},
            "labeled": False,
        }
        train_config = {
            "method": "conditional",
            "loss": {},
            "preprocess": {},
            "optimizer": {"name": "Adam"},
        }
        expected_msg = (
            "For conditional model, data have to be labeled, got unlabeled data."
        )

        with pytest.raises(ValueError) as err_info:
            config_sanity_check(
                config={"dataset": dataset_config, "train": train_config}
            )
        assert expected_msg in str(err_info.value)
146