TestResizeCPTransform.test_attributes()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 2
nop 3
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/layer
5
"""
6
7
import numpy as np
8
import pytest
9
import tensorflow as tf
10
11
import deepreg.model.layer as layer
12
13
14
@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
def test_norm_block(layer_name: str, norm_name: str, activation: str):
    """
    Check NormBlock's output shape and that get_config round-trips its arguments.

    :param layer_name: layer_name for layer definition
    :param norm_name: norm_name for layer definition
    :param activation: activation for layer definition
    """
    # keyword arguments used both to build the block and in the expected config
    block_kwargs = dict(
        layer_name=layer_name,
        norm_name=norm_name,
        activation=activation,
        filters=3,
        kernel_size=1,
        padding="same",
    )
    leading_dims = (2, 3, 4, 5)
    input_size = leading_dims + (6,)  # (batch, *shape, ch)

    norm_block = layer.NormBlock(**block_kwargs)
    outputs = norm_block(tf.ones(shape=input_size))
    # only the last axis changes: 6 input channels -> `filters` output channels
    assert outputs.shape == leading_dims + (block_kwargs["filters"],)

    expected_config = dict(
        block_kwargs, name="norm_block", trainable=True, dtype="float32"
    )
    assert norm_block.get_config() == expected_config
50
51
52
class TestWarping:
    """Tests for layer.Warping: output shape and get_config."""

    @pytest.mark.parametrize(
        ("moving_image_size", "fixed_image_size"),
        [
            ((1, 2, 3), (3, 4, 5)),
            ((1, 2, 3), (1, 2, 3)),
        ],
    )
    def test_forward(self, moving_image_size, fixed_image_size):
        """
        Warp a moving image with a dense displacement field (ddf).

        The ddf lives on the fixed image grid with a trailing axis of size 3
        (presumably one displacement per spatial axis). The warped output must
        therefore have the fixed image's spatial shape, whatever the moving
        image's shape.
        """
        batch_size = 2
        ddf_shape = (batch_size, *fixed_image_size, 3)
        image_shape = (batch_size, *moving_image_size)

        warping = layer.Warping(fixed_image_size=fixed_image_size)
        outputs = warping([tf.ones(shape=ddf_shape), tf.ones(shape=image_shape)])

        assert outputs.shape == (batch_size, *fixed_image_size)

    def test_get_config(self):
        """get_config reports fixed_image_size plus the standard layer fields."""
        expected = {
            "fixed_image_size": (2, 3, 4),
            "name": "warping",
            "trainable": True,
            "dtype": "float32",
        }
        warping = layer.Warping(fixed_image_size=(2, 3, 4))
        assert warping.get_config() == expected
76
77
78
@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
@pytest.mark.parametrize("num_layers", [2, 3])
def test_res_block(layer_name: str, norm_name: str, activation: str, num_layers: int):
    """
    Test ResidualBlock output shapes and configs.

    :param layer_name: layer_name for layer definition
    :param norm_name: norm_name for layer definition
    :param activation: activation for layer definition
    :param num_layers: num_layers passed to the residual block
    """
    ch = 3
    input_size = (2, 3, 4, 5, ch)  # (batch, *shape, ch)
    # keyword arguments used both to build the block and in the expected config
    block_kwargs = dict(
        layer_name=layer_name,
        num_layers=num_layers,
        norm_name=norm_name,
        activation=activation,
        filters=ch,
        kernel_size=3,
        padding="same",
    )

    res_block = layer.ResidualBlock(**block_kwargs)
    outputs = res_block(tf.ones(shape=input_size))
    # filters equals the input channel count, so the expected output shape is
    # the input shape; derive it from `ch` rather than a duplicated literal
    assert outputs.shape == input_size[:-1] + (ch,)

    expected_config = dict(
        block_kwargs, name="res_block", trainable=True, dtype="float32"
    )
    assert res_block.get_config() == expected_config
119
120
121
class TestIntDVF:
    """Tests for layer.IntDVF: forward shape, config, and argument validation."""

    def test_forward(self):
        """
        Test output shape and config.

        The output must keep the input shape. num_steps is not passed, so the
        config is expected to carry the layer's default of 7.
        """
        fixed_image_size = (8, 9, 10)
        input_shape = (2, *fixed_image_size, 3)
        int_layer = layer.IntDVF(fixed_image_size=fixed_image_size)

        outputs = int_layer(tf.ones(shape=input_shape))
        assert outputs.shape == input_shape

        expected_config = {
            "fixed_image_size": fixed_image_size,
            "num_steps": 7,
            "name": "int_dvf",
            "trainable": True,
            "dtype": "float32",
        }
        assert int_layer.get_config() == expected_config

    def test_err(self):
        """A fixed_image_size with only two dimensions raises AssertionError."""
        with pytest.raises(AssertionError):
            layer.IntDVF(fixed_image_size=(2, 3))
148
149
150
class TestResizeCPTransform:
    """Tests for layer.ResizeCPTransform: attributes, build() and call()."""

    @pytest.mark.parametrize(
        "expected_spacing,cp_spacing", [((8, 8, 8), 8), ((8, 24, 16), (8, 24, 16))]
    )
    def test_attributes(self, expected_spacing, cp_spacing):
        """
        An int spacing is expanded to all three axes.

        kernel_sigma is expected to be 0.44 * spacing for each axis.
        """
        model = layer.ResizeCPTransform(cp_spacing)

        if isinstance(cp_spacing, int):
            cp_spacing = [cp_spacing] * 3
        assert list(model.cp_spacing) == list(expected_spacing)
        assert model.kernel_sigma == [0.44 * cp for cp in cp_spacing]

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 68, 68, 68, 3), (12, 8, 12), (8, 16, 8)),
            ((1, 68, 68, 68, 3), (12, 12, 12), 8),
        ],
    )
    def test_build(self, input_size, output_size, cp_spacing):
        """
        build() must compute the spatial output shape of the control-point grid.

        The expected sizes match ceil(image_size / spacing) + 3 per axis. They
        equal the spatial part of what test_call expects for the same inputs.
        """
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        # A bare list comprehension is always truthy when non-empty, so compare
        # the values explicitly. int() accepts Python ints and scalar eager tensors.
        assert [int(size) for size in model._output_shape] == list(output_size)

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 68, 68, 68, 3), (1, 12, 8, 12, 3), (8, 16, 8)),
            ((1, 68, 68, 68, 3), (1, 12, 12, 12, 3), 8),
        ],
    )
    def test_call(self, input_size, output_size, cp_spacing):
        """Calling the layer resizes the spatial axes; batch and channels are kept."""
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        inputs = tf.random.normal(shape=input_size, dtype=tf.float32)
        outputs = model(inputs)

        assert outputs.shape == output_size
190
191
192
class TestBSplines3DTransform:
    """
    Test the layer.BSplines3DTransform class,
    its default attributes and its call() function.
    """

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), 8), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_init(self, input_size, cp):
        """An int control-point spacing is expanded to a 3-tuple."""
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        if isinstance(cp, int):
            cp = (cp, cp, cp)

        assert model.cp_spacing == cp

    def generate_filter_coefficients(self, cp_spacing):
        """
        Build the reference cubic B-spline filter for test_coefficients.

        For each axis i the 1-D kernel has length 4 * cp_spacing[i]. Entry
        x * cp_spacing[i] + u holds basis b_x evaluated at
        1 - (u + 0.5) / cp_spacing[i].

        The 3-D kernel is the outer product of the three 1-D kernels. It is
        written on the diagonal of the two trailing (3, 3) axes, so each of the
        3 channels maps only to itself.

        This is a helper, not a test, so it takes no pytest parametrization.

        :param cp_spacing: control-point spacing per axis (length-3 sequence)
        :return: float32 array of shape (4*s0, 4*s1, 4*s2, 3, 3)
        """
        # the four cubic B-spline basis functions b_0 .. b_3
        basis = (
            lambda t: (1 - t) ** 3 / 6,
            lambda t: (3 * (t ** 3) - 6 * (t ** 2) + 4) / 6,
            lambda t: (-3 * (t ** 3) + 3 * (t ** 2) + 3 * t + 1) / 6,
            lambda t: t ** 3 / 6,
        )

        kernels_1d = []
        for spacing in cp_spacing:
            t = 1 - (np.arange(spacing, dtype=np.float64) + 0.5) / spacing
            # concatenating b_0(t) .. b_3(t) puts b_x(t[u]) at index x * spacing + u
            kernels_1d.append(np.concatenate([b(t) for b in basis]))

        # separable 3-D kernel: kernel_3d[i, j, k] = k0[i] * k1[j] * k2[k]
        kernel_3d = (
            kernels_1d[0][:, None, None]
            * kernels_1d[1][None, :, None]
            * kernels_1d[2][None, None, :]
        )

        filters = np.zeros(kernel_3d.shape + (3, 3), dtype=np.float32)
        for dim in range(3):
            filters[..., dim, dim] = kernel_3d
        return filters

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_build(self, input_size, cp):
        """build() creates a filter of shape (4*cp0, 4*cp1, 4*cp2, 3, 3)."""
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        model.build(input_size)
        assert model.filter.shape == (
            4 * cp[0],
            4 * cp[1],
            4 * cp[2],
            3,
            3,
        )

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_coefficients(self, input_size, cp):
        """The layer's filter matches the reference B-spline coefficients."""
        filters = self.generate_filter_coefficients(cp_spacing=cp)

        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        assert np.allclose(filters, model.filter.numpy(), atol=1e-8)

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_interpolation(self, input_size, cp):
        """
        A control-point field of ceil(size / cp) + 3 points per axis is
        interpolated back to the full input shape.
        """
        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        vol_shape = input_size[1:-1]
        num_cp = (
            [input_size[0]]
            + [int(np.ceil(isize / cpsize) + 3) for isize, cpsize in zip(vol_shape, cp)]
            + [input_size[-1]]
        )

        field = tf.random.normal(shape=num_cp, dtype=tf.float32)

        ddf = model.call(field)
        assert ddf.shape == input_size
303
304
305
class TestResize3d:
    """Tests for layer.Resize3d: forward shapes, config and input validation."""

    @pytest.mark.parametrize(
        ("input_shape", "resize_shape", "output_shape"),
        [
            ((1, 2, 3), (3, 4, 5), (3, 4, 5)),
            ((2, 1, 2, 3), (3, 4, 5), (2, 3, 4, 5)),
            ((2, 1, 2, 3, 1), (3, 4, 5), (2, 3, 4, 5, 1)),
            ((2, 1, 2, 3, 6), (3, 4, 5), (2, 3, 4, 5, 6)),
            ((1, 2, 3), (1, 2, 3), (1, 2, 3)),
            ((2, 1, 2, 3), (1, 2, 3), (2, 1, 2, 3)),
            ((2, 1, 2, 3, 1), (1, 2, 3), (2, 1, 2, 3, 1)),
            ((2, 1, 2, 3, 6), (1, 2, 3), (2, 1, 2, 3, 6)),
        ],
    )
    def test_forward(self, input_shape, resize_shape, output_shape):
        """3-, 4- and 5-D inputs have their three spatial axes resized."""
        inputs = tf.ones(shape=input_shape)
        outputs = layer.Resize3d(shape=resize_shape)(inputs)
        assert outputs.shape == output_shape

    def test_get_config(self):
        """get_config reports shape, the bilinear default method and layer fields."""
        resize = layer.Resize3d(shape=(2, 3, 4))
        config = resize.get_config()
        assert config == dict(
            shape=(2, 3, 4),
            method=tf.image.ResizeMethod.BILINEAR,
            name="resize3d",
            trainable=True,
            dtype="float32",
        )

    def test_shape_err(self):
        """A target shape with only two dimensions raises AssertionError."""
        with pytest.raises(AssertionError):
            layer.Resize3d(shape=(2, 3))

    def test_image_shape_err(self):
        """Calling the layer on an input that is not 3-, 4- or 5-D raises ValueError."""
        # Build the layer outside pytest.raises so only the call can satisfy it.
        resize = layer.Resize3d(shape=(2, 3, 4))
        # Pass the shape by keyword: tf.ones(1, 1) would be read as shape=1, dtype=1.
        with pytest.raises(
            ValueError, match="Resize3d takes input image of dimension 3 or 4 or 5"
        ):
            resize(tf.ones(shape=(1, 1)))
346