Passed
Pull Request — main (#662)
by Yunguan
03:31
created

TestParseLabelLoss.test_parse_old_loss()   A

Complexity

Conditions 2

Size

Total Lines 43
Code Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 27
dl 0
loc 43
rs 9.232
c 0
b 0
f 0
cc 2
nop 3
1
import pytest
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
import yaml
3
4
from deepreg.config.v011 import (
5
    parse_image_loss,
6
    parse_label_loss,
7
    parse_loss,
8
    parse_model,
9
    parse_optimizer,
10
    parse_reg_loss,
11
    parse_v011,
12
)
13
14
15
def _load_yaml(path: str) -> dict:
    """Load a YAML config file from the repository.

    :param path: path relative to the current working directory
        (presumably the repository root when pytest runs — TODO confirm).
    :return: the parsed YAML content.
    """
    # An explicit encoding keeps decoding independent of the platform locale.
    with open(path, encoding="utf-8") as file:
        # FullLoader is kept from the original test; these are trusted
        # repository fixtures, not external input.
        return yaml.load(file, Loader=yaml.FullLoader)


@pytest.mark.parametrize(
    ("old_config_path", "latest_config_path"),
    [
        (
            "config/test/grouped_mr_heart_v011.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
        # Feeding an already-latest config checks that parsing it again
        # leaves it unchanged.
        (
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
            "demos/grouped_mr_heart/grouped_mr_heart.yaml",
        ),
    ],
)
def test_grouped_mr_heart(old_config_path: str, latest_config_path: str):
    """Check that ``parse_v011`` turns a full config into the latest demo config.

    :param old_config_path: YAML file to feed to ``parse_v011``.
    :param latest_config_path: YAML file holding the expected parsed result.
    """
    old_config = _load_yaml(old_config_path)
    latest_config = _load_yaml(latest_config_path)
    updated_config = parse_v011(old_config=old_config)
    assert updated_config == latest_config
35
36
37
class TestParseModel:
    """Tests for ``parse_model``, which converts a model config to the latest layout."""

    # Model section in the v0.1.1 layout: backbone-specific options sit under
    # a key named after the selected backbone.
    config_v011 = {
        "model": {
            "method": "dvf",
            "backbone": "global",
            "global": {"num_channel_initial": 32},
        }
    }
    # The same model in the latest layout: backbone options are merged into a
    # single ``backbone`` dict that also carries the backbone ``name``.
    config_latest = {
        "method": "dvf",
        "backbone": {"name": "global", "num_channel_initial": 32},
    }

    @pytest.mark.parametrize(
        ("model_config", "expected"),
        [
            # Whole config containing a "model" key.
            (config_v011, config_latest),
            # Bare model section.
            (config_v011["model"], config_latest),
            # Already-latest config must pass through unchanged.
            (config_latest, config_latest),
        ],
    )
    def test_parse(self, model_config: dict, expected: dict):
        """Check that ``parse_model`` returns the latest-layout model config.

        :param model_config: model config in the old or latest layout.
        :param expected: expected parsed model config.
        """
        import copy

        # The parametrized dicts are shared class attributes, and in the last
        # case ``model_config`` is the same object as ``expected``. Parsing a
        # deep copy stops an in-place mutation from leaking into other cases
        # or from making the comparison trivially true.
        got = parse_model(model_config=copy.deepcopy(model_config))
        assert got == expected
61
62
63
def test_parse_loss():
    """Check that ``parse_loss`` unwraps the old ``dissimilarity`` section.

    In the old layout the image loss options sit under a key named after the
    loss (``lncc``); the expected result places them next to ``name`` and
    ``weight`` and drops the ``dissimilarity`` wrapper.
    """
    lncc_options = {"kernel_size": 9, "kernel_type": "rectangular"}
    old_image_loss = {"name": "lncc", "weight": 2.0, "lncc": dict(lncc_options)}
    loss_config = {"dissimilarity": {"image": old_image_loss}}

    expected = {"image": {"name": "lncc", "weight": 2.0, **lncc_options}}

    assert parse_loss(loss_config=loss_config) == expected
86
87
88
class TestParseImageLoss:
    """Tests for ``parse_image_loss``."""

    def test_parse_old_loss(self):
        """Options nested under the loss name are expected next to ``name``."""
        loss_config = {
            "image": {
                "name": "lncc",
                "weight": 2.0,
                "lncc": {
                    "kernel_size": 9,
                    "kernel_type": "rectangular",
                },
            },
        }
        expected = {
            "image": {
                "name": "lncc",
                "weight": 2.0,
                "kernel_size": 9,
                "kernel_type": "rectangular",
            },
        }
        got = parse_image_loss(loss_config=loss_config)
        assert got == expected

    def test_parse_multiple_loss(self):
        """A list of image losses already in the latest layout is kept as is."""
        import copy

        loss_config = {
            "image": [
                {
                    "name": "lncc",
                    "weight": 0.5,
                    "kernel_size": 9,
                    "kernel_type": "rectangular",
                },
                {
                    "name": "ssd",
                    "weight": 0.5,
                },
            ],
        }
        # Snapshot the input before parsing: comparing the result against the
        # very dict passed in would pass even if the parser mutated it.
        expected = copy.deepcopy(loss_config)

        got = parse_image_loss(loss_config=loss_config)
        assert got == expected
129
130
131
class TestParseLabelLoss:
    """Tests for ``parse_label_loss``."""

    @pytest.mark.parametrize(
        ("name_loss", "expected_config"),
        [
            (
                "multi_scale",
                {
                    "label": {
                        "name": "ssd",
                        "weight": 2.0,
                        "scales": [0, 1],
                    },
                },
            ),
            (
                "single_scale",
                {
                    "label": {
                        "name": "dice",
                        "weight": 1.0,
                    },
                },
            ),
        ],
    )
    def test_parse_old_loss(self, name_loss: str, expected_config: dict):
        """Old single-/multi-scale label losses map to the latest layout.

        The old config carries options for both variants and selects one via
        ``name``. Per the expectations above:

        - ``mean-squared`` with ``loss_scales`` becomes ``ssd`` with ``scales``.
        - ``dice_generalized`` becomes ``dice``.

        :param name_loss: which old variant (``single_scale``/``multi_scale``)
            is selected.
        :param expected_config: expected parsed loss config.
        """
        loss_config = {
            "label": {
                "name": name_loss,
                "single_scale": {
                    "loss_type": "dice_generalized",
                },
                "multi_scale": {
                    "loss_type": "mean-squared",
                    "loss_scales": [0, 1],
                },
            },
        }

        # Only the multi-scale case supplies an explicit weight; the
        # single-scale case expects weight 1.0 without one being given.
        if name_loss == "multi_scale":
            loss_config["label"]["weight"] = 2.0

        got = parse_label_loss(loss_config=loss_config)
        assert got == expected_config

    def test_parse_multiple_loss(self):
        """A list of label losses already in the latest layout is kept as is."""
        import copy

        loss_config = {
            "label": [
                {
                    "name": "dice",
                    "weight": 1.0,
                },
                {
                    "name": "cross-entropy",
                    "weight": 1.0,
                },
            ],
        }
        # Snapshot the input before parsing: comparing the result against the
        # very dict passed in would pass even if the parser mutated it.
        expected = copy.deepcopy(loss_config)

        got = parse_label_loss(loss_config=loss_config)
        assert got == expected
192
193
194
class TestParseRegularizationLoss:
    """Tests for ``parse_reg_loss``."""

    @pytest.mark.parametrize(
        ("energy_type", "loss_name", "extra_args"),
        [
            ("bending", "bending", {}),
            ("gradient-l2", "gradient", {"l1": False}),
            ("gradient-l1", "gradient", {"l1": True}),
        ],
    )
    def test_parse_old_loss(self, energy_type: str, loss_name: str, extra_args: dict):
        """An old ``energy_type`` maps to a loss ``name`` plus extra arguments.

        :param energy_type: old-layout regularization identifier.
        :param loss_name: expected ``name`` in the latest layout.
        :param extra_args: extra keys expected alongside ``name``/``weight``
            (e.g. the ``l1`` flag for gradient losses).
        """
        loss_config = {
            "regularization": {
                "energy_type": energy_type,
                "weight": 2.0,
            }
        }
        expected = {
            "regularization": {
                "name": loss_name,
                "weight": 2.0,
                **extra_args,
            },
        }
        got = parse_reg_loss(loss_config=loss_config)
        assert got == expected

    def test_parse_multiple_reg_loss(self):
        """A list of regularization losses in the latest layout is kept as is."""
        import copy

        loss_config = {
            "regularization": [
                {
                    "name": "bending",
                    "weight": 2.0,
                },
                {
                    "name": "gradient",
                    "weight": 2.0,
                    "l1": True,
                },
            ],
        }
        # Snapshot the input before parsing: comparing the result against the
        # very dict passed in would pass even if the parser mutated it.
        expected = copy.deepcopy(loss_config)

        got = parse_reg_loss(loss_config=loss_config)
        assert got == expected
237
238
239
def test_parse_optimizer():
    """Check that ``parse_optimizer`` keeps only the selected optimizer.

    The old layout lists settings for several optimizers and picks one via
    ``name``. The expected result holds just that optimizer's settings, with
    the name given as ``"Adam"``.
    """
    learning_rate = 1.0e-4
    opt_config = {
        "name": "adam",
        "adam": {"learning_rate": learning_rate},
        "sgd": {"learning_rate": learning_rate, "momentum": 0.9},
    }
    expected = {"name": "Adam", "learning_rate": learning_rate}

    assert parse_optimizer(opt_config=opt_config) == expected
256