Completed
Push — main ( d3edf2...e8f714 )
by Yunguan
21s queued 14s
created

TestParseLoss.test_parse()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
from typing import Dict
2
3
import pytest
4
import yaml
5
6
from deepreg.config.v011 import (
7
    parse_data,
8
    parse_image_loss,
9
    parse_label_loss,
10
    parse_loss,
11
    parse_model,
12
    parse_optimizer,
13
    parse_reg_loss,
14
    parse_v011,
15
)
16
17
18
@pytest.mark.parametrize(
    ("old_config_path", "latest_config_path"),
    [
        (
            "config/test/grouped_mr_heart_v011.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
        (
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
    ],
)
def test_grouped_mr_heart(old_config_path: str, latest_config_path: str):
    """
    Check that parse_v011 maps one YAML config file onto another.

    The second parametrized case uses the same file for both paths, so the
    latest config is expected to come back from parse_v011 unchanged.

    :param old_config_path: path of the YAML file passed to parse_v011.
    :param latest_config_path: path of the YAML file holding the expected result.
    """
    # Read the two files in the order (old, latest); each path is opened
    # separately so the two loaded objects are never the same instance.
    loaded = []
    for path in (old_config_path, latest_config_path):
        with open(path) as stream:
            loaded.append(yaml.load(stream, Loader=yaml.FullLoader))
    outdated, reference = loaded

    assert parse_v011(old_config=outdated) == reference
38
39
40
class TestParseData:
    """Tests of parse_data on the `config_v011` layout and on the latest layout."""

    # Old layout: the split directories are grouped under "dir", while
    # "format" and "labeled" are given once at the top level.
    config_v011 = {
        "dir": {
            "train": "dir_train",
            "test": "dir_test",
        },
        "format": "h5",
        "labeled": True,
        "type": "paired",
    }
    # Latest layout: each split carries its own "dir", "format" and "labeled".
    config_latest = {
        "train": {
            "dir": "dir_train",
            "format": "h5",
            "labeled": True,
        },
        "test": {
            "dir": "dir_test",
            "format": "h5",
            "labeled": True,
        },
        "type": "paired",
    }

    @pytest.mark.parametrize(
        ("data_config", "expected"),
        [
            (config_v011, config_latest),
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, data_config: Dict, expected: Dict):
        """
        Check that parse_data returns the latest layout for both inputs.

        :param data_config: config handed to parse_data.
        :param expected: config that parse_data must return.
        """
        from copy import deepcopy

        # Parse a private copy. The dicts above are shared class attributes
        # and, in the second case, the input is the very object used as
        # `expected`; without the copy an in-place change made by the parser
        # would be compared with itself or leak into the other case.
        got = parse_data(data_config=deepcopy(data_config))
        assert got == expected
74
75
76
class TestParseModel:
    """Tests of parse_model on old-style and latest-style model configs."""

    # Old layout: everything sits under a "model" key, "backbone" is a plain
    # name and the backbone arguments live under a key equal to that name.
    config_v011 = {
        "model": {
            "method": "dvf",
            "backbone": "global",
            "global": {"num_channel_initial": 32, "extract_levels": [0, 1, 2]},
        }
    }
    # Latest layout: no "model" wrapper, "backbone" is a dict with a "name",
    # and "depth" appears where the old config listed "extract_levels".
    config_latest = {
        "method": "dvf",
        "backbone": {"name": "global", "num_channel_initial": 32, "depth": 2},
    }

    @pytest.mark.parametrize(
        ("model_config", "expected"),
        [
            (config_v011, config_latest),
            (config_v011["model"], config_latest),
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, model_config: Dict, expected: Dict):
        """
        Check that parse_model returns the latest layout for every input.

        The old config is exercised both with and without its "model" wrapper.

        :param model_config: config handed to parse_model.
        :param expected: config that parse_model must return.
        """
        from copy import deepcopy

        # Parse a private copy. The second case's input is nested inside the
        # first case's input and the third case's input is the `expected`
        # object itself, so an in-place change by the parser would otherwise
        # alter a later case or make the comparison vacuous.
        got = parse_model(model_config=deepcopy(model_config))
        assert got == expected
100
101
102
class TestParseLoss:
    """Tests of parse_loss on old-style and latest-style loss configs."""

    # Old layout: the image loss sits under "dissimilarity" and its arguments
    # are nested under a key equal to the loss name ("lncc").
    config_v011 = {
        "dissimilarity": {
            "image": {
                "name": "lncc",
                "weight": 2.0,
                "lncc": {
                    "kernel_size": 9,
                    "kernel_type": "rectangular",
                },
            },
        }
    }
    # Latest layout: no "dissimilarity" wrapper and the loss arguments are
    # placed directly next to "name" and "weight".
    config_latest = {
        "image": {
            "name": "lncc",
            "weight": 2.0,
            "kernel_size": 9,
            "kernel_type": "rectangular",
        },
    }

    @pytest.mark.parametrize(
        ("loss_config", "expected"),
        [
            (config_v011, config_latest),
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, loss_config: Dict, expected: Dict):
        """
        Check that parse_loss returns the latest layout for both inputs.

        :param loss_config: config handed to parse_loss.
        :param expected: config that parse_loss must return.
        """
        from copy import deepcopy

        # Parse a private copy. The dicts above are shared class attributes
        # and, in the second case, the input is the `expected` object itself,
        # so an in-place change by the parser would otherwise go unnoticed.
        got = parse_loss(loss_config=deepcopy(loss_config))
        assert got == expected
134
135
136
class TestParseImageLoss:
    """Tests of parse_image_loss on single and multiple image losses."""

    # Old layout: loss arguments are nested under a key equal to the loss name.
    config_v011 = {
        "image": {
            "name": "lncc",
            "weight": 2.0,
            "lncc": {
                "kernel_size": 9,
                "kernel_type": "rectangular",
            },
        },
    }
    # Latest layout: loss arguments are placed next to "name" and "weight".
    config_latest = {
        "image": {
            "name": "lncc",
            "weight": 2.0,
            "kernel_size": 9,
            "kernel_type": "rectangular",
        },
    }

    @pytest.mark.parametrize(
        ("loss_config", "expected"),
        [
            (config_v011, config_latest),
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, loss_config: Dict, expected: Dict):
        """
        Check that parse_image_loss returns the latest layout for both inputs.

        :param loss_config: config handed to parse_image_loss.
        :param expected: config that parse_image_loss must return.
        """
        from copy import deepcopy

        # Parse a private copy. The dicts above are shared class attributes
        # and, in the second case, the input is the `expected` object itself,
        # so an in-place change by the parser would otherwise go unnoticed.
        got = parse_image_loss(loss_config=deepcopy(loss_config))
        assert got == expected

    def test_parse_multiple_loss(self):
        """Check that a list of image losses is returned unchanged."""
        from copy import deepcopy

        loss_config = {
            "image": [
                {
                    "name": "lncc",
                    "weight": 0.5,
                    "kernel_size": 9,
                    "kernel_type": "rectangular",
                },
                {
                    "name": "ssd",
                    "weight": 0.5,
                },
            ],
        }
        # Snapshot the expectation before the call: comparing the result with
        # the object that was passed in could not detect an in-place change.
        expected = deepcopy(loss_config)

        got = parse_image_loss(loss_config=loss_config)
        assert got == expected
185
186
187
class TestParseLabelLoss:
    """Tests of parse_label_loss on outdated, renamed and multiple label losses."""

    @pytest.mark.parametrize(
        ("name_loss", "expected_config"),
        [
            (
                "multi_scale",
                {
                    "label": {
                        "name": "ssd",
                        "weight": 2.0,
                        "scales": [0, 1],
                    },
                },
            ),
            (
                "single_scale",
                {
                    "label": {
                        "name": "dice",
                        "weight": 1.0,
                    },
                },
            ),
        ],
    )
    def test_parse_outdated_loss(self, name_loss: str, expected_config: Dict):
        """
        Check the conversion of an outdated label loss config.

        The input always contains both the "single_scale" and "multi_scale"
        sub-configs; `name_loss` decides which one is named under "name".

        :param name_loss: value stored under "name" in the outdated config.
        :param expected_config: config that parse_label_loss must return.
        """
        outdated_config = {
            "label": {
                "name": name_loss,
                "single_scale": {
                    "loss_type": "dice_generalized",
                },
                "multi_scale": {
                    "loss_type": "mean-squared",
                    "loss_scales": [0, 1],
                },
            },
        }

        # Only the multi-scale case provides an explicit weight; the
        # single-scale case has none in the input, yet its expected config
        # carries a weight of 1.0.
        if name_loss == "multi_scale":
            outdated_config["label"]["weight"] = 2.0  # type: ignore

        got = parse_label_loss(loss_config=outdated_config)
        assert got == expected_config

    def test_parse_background_weight(self):
        """Check that the "neg_weight" key comes back as "background_weight"."""
        outdated_config = {
            "label": {
                "name": "dice",
                "weight": 1.0,
                "neg_weight": 2.0,
            },
        }
        expected_config = {
            "label": {
                "name": "dice",
                "weight": 1.0,
                "background_weight": 2.0,
            },
        }
        got = parse_label_loss(loss_config=outdated_config)
        assert got == expected_config

    def test_parse_multiple_loss(self):
        """Check that a list of label losses is returned unchanged."""
        from copy import deepcopy

        loss_config = {
            "label": [
                {
                    "name": "dice",
                    "weight": 1.0,
                },
                {
                    "name": "cross-entropy",
                    "weight": 1.0,
                },
            ],
        }
        # Snapshot the expectation before the call: comparing the result with
        # the object that was passed in could not detect an in-place change.
        expected = deepcopy(loss_config)

        got = parse_label_loss(loss_config=loss_config)
        assert got == expected
266
267
268
class TestParseRegularizationLoss:
    """Tests of parse_reg_loss on outdated and multiple regularization losses."""

    @pytest.mark.parametrize(
        ("energy_type", "loss_name", "extra_args"),
        [
            ("bending", "bending", {}),
            ("gradient-l2", "gradient", {"l1": False}),
            ("gradient-l1", "gradient", {"l1": True}),
        ],
    )
    def test_parse_outdated_loss(
        self, energy_type: str, loss_name: str, extra_args: Dict
    ):
        """
        Check the conversion of an outdated regularization config.

        :param energy_type: value stored under "energy_type" in the old config.
        :param loss_name: value expected under "name" in the result.
        :param extra_args: further key/value pairs expected in the result.
        """
        loss_config = {
            "regularization": {
                "energy_type": energy_type,
                "weight": 2.0,
            }
        }
        expected = {
            "regularization": {
                "name": loss_name,
                "weight": 2.0,
                **extra_args,
            },
        }
        got = parse_reg_loss(loss_config=loss_config)
        assert got == expected

    def test_parse_multiple_reg_loss(self):
        """Check that a list of regularization losses is returned unchanged."""
        from copy import deepcopy

        loss_config = {
            "regularization": [
                {
                    "name": "bending",
                    "weight": 2.0,
                },
                {
                    "name": "gradient",
                    "weight": 2.0,
                    "l1": True,
                },
            ],
        }
        # Snapshot the expectation before the call: comparing the result with
        # the object that was passed in could not detect an in-place change.
        expected = deepcopy(loss_config)

        got = parse_reg_loss(loss_config=loss_config)
        assert got == expected
313
314
315
class TestParseOptimizer:
    """Tests of parse_optimizer on old-style and latest-style optimizer configs."""

    # Old layout: "name" selects one of several per-optimizer sub-configs.
    config_v011 = {
        "name": "adam",
        "adam": {
            "learning_rate": 1.0e-4,
        },
        "sgd": {
            "learning_rate": 1.0e-4,
            "momentum": 0.9,
        },
    }
    # Latest layout: capitalised name with the arguments placed next to it;
    # the sub-config of the optimizer that was not selected is absent.
    config_latest = {
        "name": "Adam",
        "learning_rate": 1.0e-4,
    }

    @pytest.mark.parametrize(
        ("opt_config", "expected"),
        [
            (config_v011, config_latest),
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, opt_config: Dict, expected: Dict):
        """
        Check that parse_optimizer returns the latest layout for both inputs.

        :param opt_config: config handed to parse_optimizer.
        :param expected: config that parse_optimizer must return.
        """
        from copy import deepcopy

        # Parse a private copy. The dicts above are shared class attributes
        # and, in the second case, the input is the `expected` object itself,
        # so an in-place change by the parser would otherwise go unnoticed.
        got = parse_optimizer(opt_config=deepcopy(opt_config))
        assert got == expected
341