TestConditionalModel.test_build_loss()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 4
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/_model/network/ddf_dvf.py
5
"""
6
import itertools
7
from copy import deepcopy
8
from unittest.mock import MagicMock, patch
9
10
import pytest
11
12
from deepreg.model.network import RegistrationModel
13
from deepreg.registry import REGISTRY
14
15
# Shapes of the moving / fixed images without the batch axis.
moving_image_size = (1, 3, 5)
fixed_image_size = (2, 4, 6)
# Second dimension of the `indices` input, whose shape is (batch_size, index_size).
index_size = 2
batch_size = 3

# Extra backbone arguments keyed by backbone name;
# they are merged into the "backbone" section of a config before building.
backbone_args = {
    "local": {"extract_levels": [1, 2]},
    "global": {"extract_levels": [1, 2]},
    "unet": {"depth": 2},
}

# Base model config with one image loss, one label loss and a regularization.
config = {
    "backbone": {"num_channel_initial": 4, "control_points": 2},
    "loss": {
        "image": {"name": "lncc", "weight": 0.1},
        "label": {
            "name": "dice",
            "weight": 1,
            "scales": [0, 1],
        },
        "regularization": {"weight": 0.1, "name": "bending"},
    },
}

# Same as `config`, except the image loss is a list of three losses.
config_multiple_losses = {
    "backbone": {"num_channel_initial": 4, "control_points": 2},
    "loss": {
        "image": [
            {"name": "lncc", "weight": 0.1},
            {"name": "ssd", "weight": 0.1},
            {"name": "gmi", "weight": 0.1},
        ],
        "label": {
            "name": "dice",
            "weight": 1,
            "scales": [0, 1],
        },
        "regularization": {"weight": 0.1, "name": "bending"},
    },
}
53
54
55
@pytest.fixture
def model(method: str, labeled: bool, backbone: str) -> RegistrationModel:
    """
    Build a registration model for the given method and backbone.

    The module-level `config` is deep-copied so that the shared dict
    is never mutated by a test.

    :param method: name of method
    :param labeled: whether the data is labeled
    :param backbone: name of backbone
    :return: the built object
    """
    copied = deepcopy(config)
    copied["method"] = method
    copied["backbone"]["name"] = backbone  # type: ignore
    if method == "conditional":
        # control_points is dropped from the backbone config for this method
        copied["backbone"].pop("control_points", None)  # type: ignore
    copied["backbone"].update(backbone_args[backbone])  # type: ignore
    return REGISTRY.build_model(  # type: ignore
        config=dict(
            name=method,  # TODO we store method twice
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            index_size=index_size,
            labeled=labeled,
            batch_size=batch_size,
            config=copied,
        )
    )
82
83
84
def pytest_generate_tests(metafunc):
    """
    Test parameter generator.

    This function is called once per each test function.
    It takes the attribute `params` from the test class,
    and then uses the same `params` for all tests inside the class.
    This is specific for test of registration models only.

    Test functions defined outside a class, classes without a `params`
    attribute, and classes with an empty `params` list are left
    un-parametrized instead of raising during collection.

    This is modified from the pytest documentation,
    where their version defined the params for each test function separately.

    https://docs.pytest.org/en/stable/example/parametrize.html#parametrizing-test-methods-through-per-class-configuration

    :param metafunc: object describing the test function being collected;
        only its `cls` attribute and `parametrize` method are used.
    :return: None
    """
    # metafunc.cls is None for module-level test functions,
    # getattr with a default covers that case and classes without `params`.
    funcarglist = getattr(metafunc.cls, "params", None)
    if not funcarglist:
        return
    # the keys of the first dict define the argument names for all dicts
    argnames = sorted(funcarglist[0])
    metafunc.parametrize(
        argnames, [[funcargs[name] for name in argnames] for funcargs in funcarglist]
    )
107
108
109
class TestRegistrationModel:
    """Tests for the base RegistrationModel with building steps mocked."""

    # each test in this class runs once with labeled data and once without
    params = [dict(labeled=True), dict(labeled=False)]

    @pytest.fixture
    def empty_model(self, labeled: bool) -> RegistrationModel:
        """
        A RegistrationModel with build_model and build_loss mocked/overwritten.

        The patches are only active while the object is being constructed.

        :param labeled: whether the data is labeled
        :return: the mocked object
        """
        with patch.multiple(
            RegistrationModel,
            build_model=MagicMock(return_value=None),
            build_loss=MagicMock(return_value=None),
        ):
            return RegistrationModel(
                moving_image_size=moving_image_size,
                fixed_image_size=fixed_image_size,
                index_size=index_size,
                labeled=labeled,
                batch_size=batch_size,
                config=dict(),
            )

    def test_get_config(self, empty_model, labeled):
        """get_config returns the constructor arguments plus the model name."""
        got = empty_model.get_config()
        expected = dict(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            index_size=index_size,
            labeled=labeled,
            batch_size=batch_size,
            config=dict(),
            name="RegistrationModel",
        )
        assert got == expected

    def test_build_inputs(self, empty_model, labeled):
        """build_inputs returns 5 inputs if labeled else 3, with batched shapes."""
        inputs = empty_model.build_inputs()
        expected_inputs_len = 5 if labeled else 3
        assert len(inputs) == expected_inputs_len

        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        indices = inputs["indices"]
        assert moving_image.shape == (batch_size, *moving_image_size)
        assert fixed_image.shape == (batch_size, *fixed_image_size)
        assert indices.shape == (batch_size, index_size)

        if labeled:
            moving_label = inputs["moving_label"]
            fixed_label = inputs["fixed_label"]
            assert moving_label.shape == (batch_size, *moving_image_size)
            assert fixed_label.shape == (batch_size, *fixed_image_size)

    def test_concat_images(self, empty_model, labeled):
        """concat_images output has the fixed image size and one channel per input."""
        inputs = empty_model.build_inputs()
        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        if labeled:
            moving_label = inputs["moving_label"]
            images = empty_model.concat_images(moving_image, fixed_image, moving_label)
            assert images.shape == (batch_size, *fixed_image_size, 3)
        else:
            images = empty_model.concat_images(moving_image, fixed_image)
            assert images.shape == (batch_size, *fixed_image_size, 2)
176
177
178
class TestBuildLoss:
    """Tests for the number of losses registered under different loss configs."""

    # option selects how the image loss config is altered before building,
    # expected is the resulting number of losses of the built model
    params = [
        dict(config=config, option=0, expected=2),
        dict(config=config, option=1, expected=2),
        dict(config=config, option=2, expected=3),
        dict(config=config_multiple_losses, option=3, expected=5),
    ]

    def test_image_loss(self, config: dict, option: int, expected: int):
        """
        Build a labeled ddf model with local backbone and count its losses.

        :param config: model config, deep-copied before modification
        :param option: 0 removes the image loss config,
            1 sets the image loss weight to zero,
            2 removes the image loss weight,
            any other value leaves the config unchanged
        :param expected: expected number of losses of the built model
        """
        method = "ddf"
        backbone = "local"
        labeled = True
        copied = deepcopy(config)
        copied["method"] = method
        copied["backbone"]["name"] = backbone
        # values already present in the config take precedence over backbone_args
        copied["backbone"] = {
            **backbone_args[backbone],  # type: ignore
            **copied["backbone"],
        }

        if option == 0:
            # remove image loss config, so loss is not used
            copied["loss"].pop("image")
        elif option == 1:
            # set image loss weight to zero, so loss is not used
            copied["loss"]["image"]["weight"] = 0.0
        elif option == 2:
            # remove image loss weight, so loss is used with default weight 1
            copied["loss"]["image"].pop("weight")

        ddf_model = REGISTRY.build_model(
            config=dict(
                name=method,  # TODO we store method twice
                moving_image_size=moving_image_size,
                fixed_image_size=fixed_image_size,
                index_size=index_size,
                labeled=labeled,
                batch_size=batch_size,
                config=copied,
            )
        )

        assert len(ddf_model._model.losses) == expected  # type: ignore
221
222
223
class TestDDFModel:
    """Tests for the ddf model over labeled/unlabeled data and all backbones."""

    params = [
        dict(method=method, labeled=labeled, backbone=backbone)
        for method, labeled, backbone in itertools.product(
            ["ddf"], [True, False], ["local", "global", "unet"]
        )
    ]

    def test_build_model(self, model, labeled, backbone):
        """Check the number of outputs and the shape of each output."""
        expected_outputs_len = 3 if labeled else 2
        if backbone == "global":
            # the global backbone has an additional output theta
            expected_outputs_len += 1
            theta = model._outputs["theta"]
            assert theta.shape == (batch_size, 4, 3)
        assert len(model._outputs) == expected_outputs_len

        ddf = model._outputs["ddf"]
        pred_fixed_image = model._outputs["pred_fixed_image"]
        assert ddf.shape == (batch_size, *fixed_image_size, 3)
        assert pred_fixed_image.shape == (batch_size, *fixed_image_size)

        if labeled:
            pred_fixed_label = model._outputs["pred_fixed_label"]
            assert pred_fixed_label.shape == (batch_size, *fixed_image_size)

    def test_build_loss(self, model, labeled, backbone):
        """Check the number of losses: 3 if labeled else 2."""
        expected = 3 if labeled else 2
        assert len(model._model.losses) == expected

    def test_postprocess(self, model, labeled, backbone):
        """Check the indices shape and the number of processed outputs."""
        indices, processed = model.postprocess(
            inputs=model._inputs, outputs=model._outputs
        )
        assert indices.shape == (batch_size, index_size)
        expected = 7 if labeled else 4
        if backbone == "global":
            expected += 1
        assert len(processed) == expected
261
262
263
class TestDVFModel:
    """Tests for the dvf model over labeled/unlabeled data, local and unet backbones."""

    params = [
        dict(method=method, labeled=labeled, backbone=backbone)
        for method, labeled, backbone in itertools.product(
            ["dvf"], [True, False], ["local", "unet"]
        )
    ]

    def test_build_model(self, model, labeled, backbone):
        """Check the number of outputs and the shape of each output."""
        expected_outputs_len = 4 if labeled else 3
        assert len(model._outputs) == expected_outputs_len

        dvf = model._outputs["dvf"]
        ddf = model._outputs["ddf"]
        pred_fixed_image = model._outputs["pred_fixed_image"]
        assert dvf.shape == (batch_size, *fixed_image_size, 3)
        assert ddf.shape == (batch_size, *fixed_image_size, 3)
        assert pred_fixed_image.shape == (batch_size, *fixed_image_size)

        if labeled:
            pred_fixed_label = model._outputs["pred_fixed_label"]
            assert pred_fixed_label.shape == (batch_size, *fixed_image_size)

    def test_build_loss(self, model, labeled, backbone):
        """Check the number of losses: 3 if labeled else 2."""
        expected = 3 if labeled else 2
        assert len(model._model.losses) == expected

    def test_postprocess(self, model, labeled, backbone):
        """Check the indices shape and the number of processed outputs."""
        indices, processed = model.postprocess(
            inputs=model._inputs, outputs=model._outputs
        )
        assert indices.shape == (batch_size, index_size)
        expected = 8 if labeled else 5
        assert len(processed) == expected
297
298
299
class TestConditionalModel:
    """Tests for the conditional model, labeled data only, local and unet backbones."""

    params = [
        dict(method=method, labeled=labeled, backbone=backbone)
        for method, labeled, backbone in itertools.product(
            ["conditional"], [True], ["local", "unet"]
        )
    ]

    def test_build_model(self, model, labeled, backbone):
        """Check that the only output is pred_fixed_label with the fixed image shape."""
        assert len(model._outputs) == 1
        pred_fixed_label = model._outputs["pred_fixed_label"]
        assert pred_fixed_label.shape == (batch_size, *fixed_image_size)

    def test_build_loss(self, model, labeled, backbone):
        """Check that exactly one loss is registered."""
        assert len(model._model.losses) == 1

    def test_postprocess(self, model, labeled, backbone):
        """Check the indices shape and the number of processed outputs."""
        indices, processed = model.postprocess(
            inputs=model._inputs, outputs=model._outputs
        )
        assert indices.shape == (batch_size, index_size)
        assert len(processed) == 5
321