Passed
Pull Request — main (#699)
by Yunguan
01:23
created

TestParseLabelLoss.test_parse_background_weight()   A

Complexity

Conditions 1

Size

Total Lines 17
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 17
rs 9.75
c 0
b 0
f 0
cc 1
nop 1
1
import copy

import pytest
2
import yaml
3
4
from deepreg.config.v011 import (
5
    parse_image_loss,
6
    parse_label_loss,
7
    parse_loss,
8
    parse_model,
9
    parse_optimizer,
10
    parse_reg_loss,
11
    parse_v011,
12
)
13
14
15
def _load_yaml(path: str):
    """Load the YAML config file at ``path``.

    ``yaml.safe_load`` is used because the configs are plain data, and the
    encoding is explicit so the result does not depend on the platform's
    default text encoding.
    """
    with open(path, encoding="utf-8") as file:
        return yaml.safe_load(file)


@pytest.mark.parametrize(
    ("old_config_path", "latest_config_path"),
    [
        (
            "config/test/grouped_mr_heart_v011.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
        (
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
    ],
)
def test_grouped_mr_heart(old_config_path: str, latest_config_path: str):
    """Check that parse_v011 upgrades a config file to the latest layout.

    The first case converts a v0.1.1 config; the second feeds an already
    up-to-date config and expects it to come back unchanged.

    NOTE(review): the paths are relative, so this presumably runs from the
    repository root — confirm against the CI configuration.
    """
    old_config = _load_yaml(old_config_path)
    latest_config = _load_yaml(latest_config_path)
    updated_config = parse_v011(old_config=old_config)
    assert updated_config == latest_config
35
36
37
class TestParseModel:
    """Tests for parse_model, which upgrades the model section of a config."""

    # Model section in the v0.1.1 layout: wrapped in a "model" key, with the
    # backbone's arguments stored under a key named after the backbone.
    config_v011 = {
        "model": {
            "method": "dvf",
            "backbone": "global",
            "global": {"num_channel_initial": 32, "extract_levels": [0, 1, 2]},
        }
    }
    # The same model section in the latest layout.
    config_latest = {
        "method": "dvf",
        "backbone": {"name": "global", "num_channel_initial": 32, "depth": 2},
    }

    @pytest.mark.parametrize(
        ("model_config", "expected"),
        [
            (config_v011, config_latest),
            (config_v011["model"], config_latest),
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, model_config: dict, expected: dict):
        """Parse outdated, unwrapped and latest model configs.

        The parametrized cases share the class-level dicts, so a deep copy is
        parsed: an in-place change by parse_model can then neither leak into
        later cases nor make the last case compare ``expected`` with itself
        when parse_model returns its argument.
        """
        got = parse_model(model_config=copy.deepcopy(model_config))
        assert got == expected
61
62
63
def test_parse_loss():
    """parse_loss drops the "dissimilarity" wrapper and flattens loss args.

    The outdated layout nests the lncc arguments under an "lncc" key; the
    parsed result merges them directly into the image-loss entry.
    """
    lncc_args = {"kernel_size": 9, "kernel_type": "rectangular"}
    outdated = {
        "dissimilarity": {
            "image": {"name": "lncc", "weight": 2.0, "lncc": dict(lncc_args)},
        }
    }
    want = {"image": {"name": "lncc", "weight": 2.0, **lncc_args}}

    assert parse_loss(loss_config=outdated) == want
86
87
88
class TestParseImageLoss:
    """Tests for parse_image_loss."""

    def test_parse_outdated_loss(self):
        """The nested per-loss argument dict is flattened into the loss entry."""
        loss_config = {
            "image": {
                "name": "lncc",
                "weight": 2.0,
                "lncc": {
                    "kernel_size": 9,
                    "kernel_type": "rectangular",
                },
            },
        }
        expected = {
            "image": {
                "name": "lncc",
                "weight": 2.0,
                "kernel_size": 9,
                "kernel_type": "rectangular",
            },
        }
        got = parse_image_loss(loss_config=loss_config)
        assert got == expected

    def test_parse_multiple_loss(self):
        """A list of image losses already in the latest layout is unchanged.

        ``expected`` is a deep copy taken before parsing, so the assertion is a
        real comparison even if parse_image_loss mutates or returns its input.
        """
        loss_config = {
            "image": [
                {
                    "name": "lncc",
                    "weight": 0.5,
                    "kernel_size": 9,
                    "kernel_type": "rectangular",
                },
                {
                    "name": "ssd",
                    "weight": 0.5,
                },
            ],
        }
        expected = copy.deepcopy(loss_config)

        got = parse_image_loss(loss_config=loss_config)
        assert got == expected
129
130
131
class TestParseLabelLoss:
    """Tests for parse_label_loss."""

    @pytest.mark.parametrize(
        ("name_loss", "expected_config"),
        [
            (
                "multi_scale",
                {
                    "label": {
                        "name": "ssd",
                        "weight": 2.0,
                        "scales": [0, 1],
                    },
                },
            ),
            (
                "single_scale",
                {
                    "label": {
                        "name": "dice",
                        "weight": 1.0,
                    },
                },
            ),
        ],
    )
    def test_parse_outdated_loss(self, name_loss: str, expected_config: dict):
        """Only the sub-config selected by ``name`` is converted.

        The outdated config carries both a ``single_scale`` and a
        ``multi_scale`` sub-config. An explicit weight is supplied only in the
        multi-scale case; the single-scale case omits it and expects 1.0.
        """
        outdated_config = {
            "label": {
                "name": name_loss,
                "single_scale": {
                    "loss_type": "dice_generalized",
                },
                "multi_scale": {
                    "loss_type": "mean-squared",
                    "loss_scales": [0, 1],
                },
            },
        }

        if name_loss == "multi_scale":
            outdated_config["label"]["weight"] = 2.0  # type: ignore

        got = parse_label_loss(loss_config=outdated_config)
        assert got == expected_config

    def test_parse_background_weight(self):
        """The outdated ``neg_weight`` key is renamed to ``background_weight``."""
        outdated_config = {
            "label": {
                "name": "dice",
                "weight": 1.0,
                "neg_weight": 2.0,
            },
        }
        expected_config = {
            "label": {
                "name": "dice",
                "weight": 1.0,
                "background_weight": 2.0,
            },
        }
        got = parse_label_loss(loss_config=outdated_config)
        assert got == expected_config

    def test_parse_multiple_loss(self):
        """A list of label losses already in the latest layout is unchanged.

        ``expected`` is a deep copy taken before parsing, so the assertion is a
        real comparison even if parse_label_loss mutates or returns its input.
        """
        loss_config = {
            "label": [
                {
                    "name": "dice",
                    "weight": 1.0,
                },
                {
                    "name": "cross-entropy",
                    "weight": 1.0,
                },
            ],
        }
        expected = copy.deepcopy(loss_config)

        got = parse_label_loss(loss_config=loss_config)
        assert got == expected
210
211
212
class TestParseRegularizationLoss:
    """Tests for parse_reg_loss."""

    @pytest.mark.parametrize(
        ("energy_type", "loss_name", "extra_args"),
        [
            ("bending", "bending", {}),
            ("gradient-l2", "gradient", {"l1": False}),
            ("gradient-l1", "gradient", {"l1": True}),
        ],
    )
    def test_parse_outdated_loss(
        self, energy_type: str, loss_name: str, extra_args: dict
    ):
        """An outdated ``energy_type`` maps to a loss name plus extra args.

        The gradient variants collapse to ``"gradient"`` with an ``l1`` flag.
        """
        loss_config = {
            "regularization": {
                "energy_type": energy_type,
                "weight": 2.0,
            }
        }
        expected = {
            "regularization": {
                "name": loss_name,
                "weight": 2.0,
                **extra_args,
            },
        }
        got = parse_reg_loss(loss_config=loss_config)
        assert got == expected

    def test_parse_multiple_reg_loss(self):
        """A list of regularization losses in the latest layout is unchanged.

        ``expected`` is a deep copy taken before parsing, so the assertion is a
        real comparison even if parse_reg_loss mutates or returns its input.
        """
        loss_config = {
            "regularization": [
                {
                    "name": "bending",
                    "weight": 2.0,
                },
                {
                    "name": "gradient",
                    "weight": 2.0,
                    "l1": True,
                },
            ],
        }
        expected = copy.deepcopy(loss_config)

        got = parse_reg_loss(loss_config=loss_config)
        assert got == expected
257
258
259
def test_parse_optimizer():
    """parse_optimizer keeps only the arguments of the selected optimizer.

    The outdated config lists arguments for both adam and sgd; only the adam
    block (chosen by ``name``) survives, flattened, with the name written as
    "Adam".
    """
    learning_rate = 1.0e-4
    outdated = {
        "name": "adam",
        "adam": {"learning_rate": learning_rate},
        "sgd": {"learning_rate": learning_rate, "momentum": 0.9},
    }

    assert parse_optimizer(opt_config=outdated) == {
        "name": "Adam",
        "learning_rate": learning_rate,
    }
276