Passed
Pull Request — main (#651)
by Yunguan
03:48
created

test.unit.test_config_v011.test_parse_image_loss()   A

Complexity

Conditions 1

Size

Total Lines 21
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 16
dl 0
loc 21
rs 9.6
c 0
b 0
f 0
cc 1
nop 0
1
import pytest
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
import yaml
3
4
from deepreg.config.v011 import (
5
    parse_image_loss,
6
    parse_label_loss,
7
    parse_loss,
8
    parse_model,
9
    parse_optimizer,
10
    parse_reg_loss,
11
    parse_v011,
12
)
13
14
15
def _load_yaml(path: str):
    """Read the YAML file at ``path`` and return its parsed content."""
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, encoding="utf-8") as file:
        return yaml.load(file, Loader=yaml.FullLoader)


@pytest.mark.parametrize(
    ("old_config_path", "latest_config_path"),
    [
        (
            "config/test/grouped_mr_heart_v011.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
        (
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
    ],
)
def test_grouped_mr_heart(old_config_path: str, latest_config_path: str):
    """Check that ``parse_v011`` turns the old config into the latest demo config.

    The second case feeds the latest config itself through ``parse_v011`` and
    expects it to come out unchanged.

    :param old_config_path: YAML file passed through ``parse_v011``.
    :param latest_config_path: YAML file holding the expected result.
    """
    # NOTE(review): the paths are relative, so this presumably assumes pytest
    # is run from the repository root — confirm.
    old_config = _load_yaml(old_config_path)
    latest_config = _load_yaml(latest_config_path)
    updated_config = parse_v011(old_config=old_config)
    assert updated_config == latest_config
35
36
37
class TestParseModel:
    """Tests for ``parse_model`` converting a v0.1.1 model config."""

    # v0.1.1 layout: ``backbone`` is a plain string and that backbone's
    # kwargs live under a key named after it (here ``global``).
    config_v011 = {
        "model": {
            "method": "dvf",
            "backbone": "global",
            "global": {"num_channel_initial": 32},
        }
    }
    # Expected output: the backbone name and its kwargs merged into one dict.
    config_latest = {
        "method": "dvf",
        "backbone": {"name": "global", "num_channel_initial": 32},
    }

    @pytest.mark.parametrize(
        ("model_config", "expected"),
        [
            (config_v011, config_latest),
            (config_v011["model"], config_latest),
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, model_config: dict, expected: dict):
        """Check that wrapped, unwrapped, and already-latest configs all parse.

        The parametrized dicts are shared class attributes:
        ``config_v011["model"]`` is nested inside ``config_v011``, and in the
        last case ``model_config`` is the same object as ``expected``. The
        input is therefore deep-copied before parsing. An in-place mutation
        by ``parse_model`` then cannot leak into other cases or silently
        rewrite the expected value.
        """
        import copy

        got = parse_model(model_config=copy.deepcopy(model_config))
        assert got == expected
61
62
63
def test_parse_loss():
    """Check that ``parse_loss`` flattens the nested v0.1.1 image dissimilarity config.

    The options stored under ``dissimilarity.image.lncc`` end up next to
    ``name`` and ``weight``, and the ``dissimilarity`` level is dropped.
    """
    lncc_kwargs = {"kernel_size": 9, "kernel_type": "rectangular"}
    old_loss = {
        "dissimilarity": {
            "image": {"name": "lncc", "weight": 2.0, "lncc": dict(lncc_kwargs)},
        }
    }
    want = {"image": {"name": "lncc", "weight": 2.0, **lncc_kwargs}}

    assert parse_loss(loss_config=old_loss) == want
86
87
88
def test_parse_image_loss():
    """Check that ``parse_image_loss`` hoists the ``lncc`` options into the ``image`` dict."""
    got = parse_image_loss(
        loss_config={
            "image": {
                "name": "lncc",
                "weight": 2.0,
                "lncc": {"kernel_size": 9, "kernel_type": "rectangular"},
            },
        }
    )
    assert got == {
        "image": {
            "name": "lncc",
            "weight": 2.0,
            "kernel_size": 9,
            "kernel_type": "rectangular",
        },
    }
109
110
111
class TestParseLabelLoss:
    """Tests for ``parse_label_loss`` on the v0.1.1 label-loss layouts."""

    def test_label_multi_scale(self):
        """Check that a ``multi_scale`` mean-squared loss becomes ``ssd`` with ``scales``."""
        scale_opts = {"loss_type": "mean-squared", "loss_scales": [0, 1]}
        old = {
            "label": {"name": "multi_scale", "weight": 2.0, "multi_scale": scale_opts},
        }
        want = {"label": {"name": "ssd", "weight": 2.0, "scales": [0, 1]}}

        assert parse_label_loss(loss_config=old) == want

    def test_label_single_scale(self):
        """Check that a ``single_scale`` generalized-dice loss becomes ``dice``.

        The input carries no ``weight``, and the expected output has
        ``weight`` 1.0. The extra ``multi_scale`` section does not appear in
        the output.
        """
        old = {
            "label": {
                "name": "single_scale",
                "single_scale": {"loss_type": "dice_generalized"},
                "multi_scale": {"loss_type": "mean-squared", "loss_scales": [0, 1]},
            },
        }
        want = {"label": {"name": "dice", "weight": 1.0}}

        assert parse_label_loss(loss_config=old) == want
154
155
156
@pytest.mark.parametrize(
    ("energy_type", "loss_name", "extra_args"),
    [
        ("bending", "bending", {}),
        ("gradient-l2", "gradient", {"l1": False}),
        ("gradient-l1", "gradient", {"l1": True}),
    ],
)
def test_parse_reg_loss(energy_type: str, loss_name: str, extra_args: dict):
    """Check that ``energy_type`` maps to a loss ``name`` plus any extra kwargs.

    For the gradient energies, the l1/l2 suffix becomes an ``l1`` flag.
    ``weight`` is passed through unchanged.
    """
    want = {"name": loss_name, "weight": 2.0}
    want.update(extra_args)

    got = parse_reg_loss(
        loss_config={"regularization": {"energy_type": energy_type, "weight": 2.0}}
    )
    assert got == {"regularization": want}
180
181
182
def test_parse_optimizer():
    """Check that only the selected optimizer's options survive ``parse_optimizer``.

    The config selects ``adam``. The ``sgd`` section is discarded, ``adam``'s
    options are merged next to ``name``, and ``name`` becomes ``"Adam"``.
    """
    shared_lr = 1.0e-4
    opt_v011 = {
        "name": "adam",
        "adam": {"learning_rate": shared_lr},
        "sgd": {"learning_rate": shared_lr, "momentum": 0.9},
    }

    assert parse_optimizer(opt_config=opt_v011) == {
        "name": "Adam",
        "learning_rate": shared_lr,
    }
199