Passed
Pull Request — main (#656)
by Yunguan · created 02:56

test.unit.test_layer.test_norm_block() — rated A

Complexity
    Conditions: 1

Size
    Total Lines: 35
    Code Lines: 26

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric                               Value
eloc (executable lines of code)      26
dl (duplicated lines)                0
loc (lines of code)                  35
rs                                   9.256
c                                    0
b                                    0
f                                    0
cc (cyclomatic complexity)           1
nop (number of parameters)           3

# coding=utf-8

"""
Tests for deepreg/model/layer
"""
from typing import Optional

import numpy as np
import pytest
import tensorflow as tf

import deepreg.model.layer as layer

@pytest.mark.parametrize(
    ("input_shape", "output_shape", "expected_shape"),
    [
        ((6, 7, 8), (12, 14, 16), (12, 14, 16)),
        ((6, 7, 8), (11, 13, 15), (11, 13, 15)),
        ((6, 7, 8), None, (12, 14, 16)),
    ],
)
[Issue] "input_shape" missing in parameter type documentation.
def test_deconv3d(input_shape, output_shape: Optional[tuple], expected_shape: tuple):
    """
    Test output shapes and configs.

    :param input_shape: input shape for layer definition
    :param output_shape: output shape for layer definition
    :param expected_shape: expected output shape
    """
    batch_size = 5

    deconv3d = layer.Deconv3d(filters=3, strides=2, output_shape=output_shape)

    inputs = tf.ones(shape=(batch_size,) + input_shape + (1,))
    output = deconv3d(inputs)

    assert output.shape == (batch_size,) + expected_shape + (3,)

    config = deconv3d.get_config()
    assert config == dict(
        filters=3,
        output_shape=output_shape,
        kernel_size=3,
        strides=2,
        padding="same",
        name="deconv3d",
        trainable=True,
        dtype="float32",
    )
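
With padding="same" and strides=2, the transposed convolution doubles every spatial dimension when output_shape is None, which is the ((6, 7, 8) -> (12, 14, 16)) case above; an explicit output_shape such as (11, 13, 15) is still reachable because, for stride s and input size d, any size from s*d - (s - 1) up to s*d can be produced by adjusting the output padding. A minimal sketch of that arithmetic (plain Python; the helper name is hypothetical):

def deconv_same_output_range(input_dim: int, stride: int) -> range:
    # Output sizes a strided transposed conv with padding="same" can produce.
    return range(stride * input_dim - (stride - 1), stride * input_dim + 1)

assert 12 in deconv_same_output_range(6, 2)  # output_shape=None: 6 -> 12
assert 11 in deconv_same_output_range(6, 2)  # explicit output_shape (11, 13, 15)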


@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
def test_norm_block(layer_name: str, norm_name: str, activation: str):
    """
    Test output shapes and configs.

    :param layer_name: layer_name for layer definition
    :param norm_name: norm_name for layer definition
    :param activation: activation for layer definition
    """
    input_size = (2, 3, 4, 5, 6)  # (batch, *shape, ch)
    norm_block = layer.NormBlock(
        layer_name=layer_name,
        norm_name=norm_name,
        activation=activation,
        filters=3,
        kernel_size=1,
        padding="same",
    )
    inputs = tf.ones(shape=input_size)
    outputs = norm_block(inputs)
    assert outputs.shape == input_size[:-1] + (3,)

    config = norm_block.get_config()
    assert config == dict(
        layer_name=layer_name,
        norm_name=norm_name,
        activation=activation,
        filters=3,
        kernel_size=1,
        padding="same",
        name="norm_block",
        trainable=True,
        dtype="float32",
    )
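
Judging by the config keys, NormBlock chains a convolution or deconvolution with a normalization layer and an activation; the following Keras sketch of that pattern is an assumption about the composition, not DeepReg's actual implementation:

class NormBlockSketch(tf.keras.layers.Layer):
    """Hypothetical (de)conv -> norm -> activation block."""

    def __init__(self, layer_name: str, norm_name: str, activation: str, **kwargs):
        super().__init__()
        conv_cls = {
            "conv3d": tf.keras.layers.Conv3D,
            "deconv3d": tf.keras.layers.Conv3DTranspose,
        }[layer_name]
        norm_cls = {
            "batch": tf.keras.layers.BatchNormalization,
            "layer": tf.keras.layers.LayerNormalization,
        }[norm_name]
        self._layer = conv_cls(**kwargs)  # e.g. filters, kernel_size, padding
        self._norm = norm_cls()
        self._act = tf.keras.layers.Activation(activation)

    def call(self, inputs, training=None):
        return self._act(self._norm(self._layer(inputs), training=training))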


def test_upsample_resnet_block():
    """
    Test the layer.UpSampleResnetBlock class and its default attributes.
    """
    batch_size = 5
    channels = 4
    input_size = (32, 32, 16)
    output_size = (64, 64, 32)

    input_tensor_size = (batch_size,) + input_size + (channels,)
    skip_tensor_size = (batch_size,) + output_size + (channels // 2,)

    model = layer.UpSampleResnetBlock(8)
    model.build([input_tensor_size, skip_tensor_size])

    assert model._filters == 8

[Coding Style Best Practice] It seems like _filters was declared protected and should not be accessed from this context. Prefixing a member variable with _ is usually regarded as the equivalent of declaring it with the protected visibility that exists in other languages; consequently, such a member should only be accessed from the same class or a child class:

class MyParent:
    def __init__(self):
        self._x = 1
        self.y = 2

class MyChild(MyParent):
    def some_method(self):
        return self._x  # OK, since accessed from a child class

class AnotherClass:
    def some_method(self, instance_of_my_child):
        # Flagged: AnotherClass is not a child class of MyParent.
        return instance_of_my_child._x

    assert model._concat is False
    assert isinstance(model._conv3d_block, layer.Conv3dBlock)
    assert isinstance(model._residual_block, layer.ResidualConv3dBlock)
    assert isinstance(model._deconv3d_block, layer.Deconv3dBlock)

[Coding Style Best Practice] _concat, _conv3d_block, _residual_block and _deconv3d_block are likewise protected and accessed from outside the class; the same finding is raised once per assertion.
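
Where a layer exposes get_config(), the same facts can often be asserted through the public surface instead of protected members; a hedged sketch (it assumes UpSampleResnetBlock serialises a "filters" entry, which this report does not confirm):

model = layer.UpSampleResnetBlock(8)
config = model.get_config()  # public API instead of model._filters
assert config["filters"] == 8  # hypothetical config key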


class TestWarping:
[Issue] Missing class docstring.
    @pytest.mark.parametrize(
        ("moving_image_size", "fixed_image_size"),
        [
            ((1, 2, 3), (3, 4, 5)),
            ((1, 2, 3), (1, 2, 3)),
        ],
    )
[Issue] Missing function or method docstring.
    def test_forward(self, moving_image_size, fixed_image_size):
        batch_size = 2
        image = tf.ones(shape=(batch_size,) + moving_image_size)
        ddf = tf.ones(shape=(batch_size,) + fixed_image_size + (3,))
        outputs = layer.Warping(fixed_image_size=fixed_image_size)([ddf, image])
        assert outputs.shape == (batch_size, *fixed_image_size)

    def test_get_config(self):
[Issue] Missing function or method docstring.
        warping = layer.Warping(fixed_image_size=(2, 3, 4))
        config = warping.get_config()
        assert config == dict(
            fixed_image_size=(2, 3, 4),
            name="warping",
            trainable=True,
            dtype="float32",
        )
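
Warping resamples the moving image at the fixed-image grid positions displaced by the dense displacement field, so the output always takes the fixed image's spatial shape regardless of the moving image's, as the first parametrised case shows. Schematically (a sketch of the semantics, not DeepReg's implementation):

grid = tf.stack(
    tf.meshgrid(
        *[tf.range(s, dtype=tf.float32) for s in fixed_image_size], indexing="ij"
    ),
    axis=-1,
)  # (*fixed_image_size, 3), with grid[i, j, k] = (i, j, k)
sample_coords = grid[None, ...] + ddf  # voxel coordinates into the moving image
# output[b, i, j, k] = image[b] linearly interpolated at sample_coords[b, i, j, k]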


@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
@pytest.mark.parametrize("num_layers", [2, 3])
def test_res_block(layer_name: str, norm_name: str, activation: str, num_layers: int):
    """
    Test output shapes and configs.

    :param layer_name: layer_name for layer definition
    :param norm_name: norm_name for layer definition
    :param activation: activation for layer definition
    :param num_layers: number of blocks in res block
    """
    ch = 3
    input_size = (2, 3, 4, 5, ch)  # (batch, *shape, ch)
    res_block = layer.ResidualBlock(
        layer_name=layer_name,
        num_layers=num_layers,
        norm_name=norm_name,
        activation=activation,
        filters=ch,
        kernel_size=3,
        padding="same",
    )
    inputs = tf.ones(shape=input_size)
    outputs = res_block(inputs)
    assert outputs.shape == input_size[:-1] + (3,)

    config = res_block.get_config()
    assert config == dict(
        layer_name=layer_name,
        num_layers=num_layers,
        norm_name=norm_name,
        activation=activation,
        filters=ch,
        kernel_size=3,
        padding="same",
        name="res_block",
        trainable=True,
        dtype="float32",
    )
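
A residual block adds its convolutional path back onto the identity path, and the addition is only well defined when filters equals the input channel count, which is why the test fixes filters=ch. In outline, consistent with the shape assertion above (conv_path is a hypothetical name for the block's convolutional branch):

outputs = inputs + conv_path(inputs)  # both (batch, *shape, ch)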


class TestIntDVF:
[Issue] Missing class docstring.
    def test_forward(self):
        """
        Test output shape and config.
        """

        fixed_image_size = (8, 9, 10)
        input_shape = (2, *fixed_image_size, 3)

        int_layer = layer.IntDVF(fixed_image_size=fixed_image_size)

        inputs = tf.ones(shape=input_shape)
        outputs = int_layer(inputs)
        assert outputs.shape == input_shape

        config = int_layer.get_config()
        assert config == dict(
            fixed_image_size=fixed_image_size,
            num_steps=7,
            name="int_dvf",
            trainable=True,
            dtype="float32",
        )

    def test_err(self):
[Issue] Missing function or method docstring.
        with pytest.raises(AssertionError):
            layer.IntDVF(fixed_image_size=(2, 3))
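
test_err exercises the guard that fixed_image_size must have exactly three spatial dimensions. IntDVF itself integrates a stationary velocity field into a dense displacement field, conventionally via scaling and squaring: divide the velocity by 2 ** num_steps (num_steps=7 in the config above), then compose the field with itself num_steps times. A sketch under that assumption, with warp(image, ddf) standing in as a hypothetical resampling helper:

dvf = velocity / (2 ** num_steps)  # scale the velocity field down
for _ in range(num_steps):
    dvf = dvf + warp(dvf, dvf)  # squaring step: compose the flow with itself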


def test_local_net_residual3d_block():
    """
    Test the layer.LocalNetResidual3dBlock class's
    default attributes and call() function.
    """

    # Test __init__()
    conv3d_block = layer.LocalNetResidual3dBlock(8)

    assert conv3d_block._conv3d.kernel_size == (3, 3, 3)
    assert conv3d_block._conv3d.strides == (1, 1, 1)
    assert conv3d_block._conv3d.padding == "same"
    assert conv3d_block._conv3d.use_bias is False

[Coding Style Best Practice] _conv3d is protected and should not be accessed from this context; the finding is raised on each of the four assertions above (same explanation as earlier).


def test_local_net_upsample_resnet_block():
    """
    Test the layer.LocalNetUpSampleResnetBlock class,
    its default attributes and its call() function.
    """
    batch_size = 5
    channels = 4
    input_size = (32, 32, 16)
    output_size = (64, 64, 32)

    nonskip_tensor_size = (batch_size,) + input_size + (channels,)
    skip_tensor_size = (batch_size,) + output_size + (channels,)

    # Test __init__() and build()
    model = layer.LocalNetUpSampleResnetBlock(8)
    model.build([nonskip_tensor_size, skip_tensor_size])

    assert model._filters == 8
    assert model._use_additive_upsampling is True

    assert isinstance(model._deconv3d_block, layer.Deconv3dBlock)
    assert isinstance(model._conv3d_block, layer.Conv3dBlock)
    assert isinstance(model._residual_block, layer.LocalNetResidual3dBlock)

[Coding Style Best Practice] _filters, _use_additive_upsampling, _deconv3d_block, _conv3d_block and _residual_block are protected and should not be accessed from this context; the finding is raised once per assertion (same explanation as earlier).
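
The _use_additive_upsampling flag suggests the additive upsampling trick from the LocalNet architecture: resize the input to the target shape, split it into groups along the channel axis, and sum the groups, yielding a parameter-free upsampling path that is added to the deconvolution output. A sketch under that assumption (the helper and its arguments are hypothetical; layer.Resize3d is the resizing layer tested later in this file):

def additive_upsampling(inputs, output_shape, num_splits):
    # (..., ch) -> (..., ch // num_splits): resize spatially, then sum channel groups.
    resized = layer.Resize3d(shape=output_shape)(inputs)
    splits = tf.split(resized, num_or_size_splits=num_splits, axis=-1)
    return tf.add_n(splits)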


class TestResizeCPTransform:
[Issue] Missing class docstring.
    @pytest.mark.parametrize(
        "parameter,cp_spacing", [((8, 8, 8), 8), ((8, 24, 16), (8, 24, 16))]
    )
[Issue] Missing function or method docstring.
    def test_attributes(self, parameter, cp_spacing):
        model = layer.ResizeCPTransform(cp_spacing)

        if isinstance(cp_spacing, int):
            cp_spacing = [cp_spacing] * 3
        assert list(model.cp_spacing) == list(parameter)
        assert model.kernel_sigma == [0.44 * cp for cp in cp_spacing]

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 8, 8, 8, 3), (12, 8, 12), (8, 16, 8)),
            ((1, 8, 8, 8, 3), (12, 12, 12), 8),
        ],
    )
[Issue] Missing function or method docstring.
    def test_build(self, input_size, output_size, cp_spacing):
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        # A bare list comprehension is always truthy; all() makes the check real.
        assert all(a == b for a, b in zip(model._output_shape, output_size))
[Coding Style Best Practice] _output_shape is protected and should not be accessed from this context (same explanation as earlier).

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 68, 68, 68, 3), (1, 12, 8, 12, 3), (8, 16, 8)),
            ((1, 68, 68, 68, 3), (1, 12, 12, 12, 3), 8),
        ],
    )
[Issue] Missing function or method docstring.
    def test_call(self, input_size, output_size, cp_spacing):
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        input = tf.random.normal(shape=input_size, dtype=tf.float32)
[Bug Best Practice] This re-defines the built-in input; redefining built-ins is generally discouraged, as it makes code hard to read.
        output = model(input)

        assert output.shape == output_size


class TestBSplines3DTransform:
    """
    Test the layer.BSplines3DTransform class,
    its default attributes and its call() function.
    """

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), 8), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
[Issue] Missing function or method docstring.
    def test_init(self, input_size, cp):
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        if isinstance(cp, int):
            cp = (cp, cp, cp)

        assert model.cp_spacing == cp

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
[Issue] Missing function or method docstring.
    def generate_filter_coefficients(self, cp_spacing):

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * cp_spacing[0],
                4 * cp_spacing[1],
                4 * cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        for u in range(cp_spacing[0]):
[unused-code] Too many nested blocks (7/5).
            for v in range(cp_spacing[1]):
                for w in range(cp_spacing[2]):
                    for x in range(4):
                        for y in range(4):
                            for z in range(4):
                                for it_dim in range(3):
                                    u_norm = 1 - (u + 0.5) / cp_spacing[0]
                                    v_norm = 1 - (v + 0.5) / cp_spacing[1]
                                    w_norm = 1 - (w + 0.5) / cp_spacing[2]
                                    filters[
                                        x * cp_spacing[0] + u,
                                        y * cp_spacing[1] + v,
                                        z * cp_spacing[2] + w,
                                        it_dim,
                                        it_dim,
                                    ] = (
                                        b[x](u_norm) * b[y](v_norm) * b[z](w_norm)
                                    )
        return filters
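
The b dictionary implements the four cubic B-spline basis polynomials, which form a partition of unity on the unit interval:

B_0(u) = \frac{(1 - u)^3}{6}, \quad
B_1(u) = \frac{3u^3 - 6u^2 + 4}{6}, \quad
B_2(u) = \frac{-3u^3 + 3u^2 + 3u + 1}{6}, \quad
B_3(u) = \frac{u^3}{6}, \qquad
\sum_{i=0}^{3} B_i(u) = 1 \quad \text{for } u \in [0, 1].

The "too many nested blocks" finding above could be addressed without changing behaviour by flattening the index loops with itertools.product (a sketch; it assumes an additional import itertools at the top of the file):

for u, v, w in itertools.product(*(range(s) for s in cp_spacing)):
    u_norm = 1 - (u + 0.5) / cp_spacing[0]
    v_norm = 1 - (v + 0.5) / cp_spacing[1]
    w_norm = 1 - (w + 0.5) / cp_spacing[2]
    for x, y, z in itertools.product(range(4), repeat=3):
        coefficient = b[x](u_norm) * b[y](v_norm) * b[z](w_norm)
        for it_dim in range(3):
            filters[
                x * cp_spacing[0] + u,
                y * cp_spacing[1] + v,
                z * cp_spacing[2] + w,
                it_dim,
                it_dim,
            ] = coefficient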

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
[Issue] Missing function or method docstring.
    def test_build(self, input_size, cp):
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        model.build(input_size)
        assert model.filter.shape == (
            4 * cp[0],
            4 * cp[1],
            4 * cp[2],
            3,
            3,
        )

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
[Issue] Missing function or method docstring.
    def test_coefficients(self, input_size, cp):

        filters = self.generate_filter_coefficients(cp_spacing=cp)

        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        assert np.allclose(filters, model.filter.numpy(), atol=1e-8)

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
[Issue] Missing function or method docstring.
    def test_interpolation(self, input_size, cp):
        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        vol_shape = input_size[1:-1]
        num_cp = (
            [input_size[0]]
            + [int(np.ceil(isize / cpsize) + 3) for isize, cpsize in zip(vol_shape, cp)]
            + [input_size[-1]]
        )

        field = tf.random.normal(shape=num_cp, dtype=tf.float32)

        ddf = model.call(field)
        assert ddf.shape == input_size
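
The control-point count per spatial dimension is ceil(size / spacing) + 3: every output voxel draws on 4 consecutive control points of a cubic B-spline lattice, so 3 points are needed beyond the grid that merely covers the volume. For example, with volume size 8 and spacing 16, ceil(8 / 16) + 3 = 1 + 3 = 4 control points.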


class TestResize3d:
[Issue] Missing class docstring.
    @pytest.mark.parametrize(
        ("input_shape", "resize_shape", "output_shape"),
        [
            ((1, 2, 3), (3, 4, 5), (3, 4, 5)),
            ((2, 1, 2, 3), (3, 4, 5), (2, 3, 4, 5)),
            ((2, 1, 2, 3, 1), (3, 4, 5), (2, 3, 4, 5, 1)),
            ((2, 1, 2, 3, 6), (3, 4, 5), (2, 3, 4, 5, 6)),
            ((1, 2, 3), (1, 2, 3), (1, 2, 3)),
            ((2, 1, 2, 3), (1, 2, 3), (2, 1, 2, 3)),
            ((2, 1, 2, 3, 1), (1, 2, 3), (2, 1, 2, 3, 1)),
            ((2, 1, 2, 3, 6), (1, 2, 3), (2, 1, 2, 3, 6)),
        ],
    )
[Issue] Missing function or method docstring.
    def test_forward(self, input_shape, resize_shape, output_shape):
        inputs = tf.ones(shape=input_shape)
        outputs = layer.Resize3d(shape=resize_shape)(inputs)
        assert outputs.shape == output_shape
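
The parametrised cases cover the three accepted input ranks: rank 3 is a bare volume, rank 4 adds a batch axis, and rank 5 adds a channel axis; any axes added for resizing are evidently squeezed out again, since the outputs keep the input's rank. A sketch of that shape normalisation (an assumption about the implementation; only the error message is taken from test_image_shape_err below):

def normalize_to_rank5(image):
    if image.shape.rank == 3:  # bare volume: add batch and channel axes
        return image[None, ..., None]
    if image.shape.rank == 4:  # batched volumes: add a channel axis
        return image[..., None]
    if image.shape.rank == 5:
        return image
    raise ValueError("Resize3d takes input image of dimension 3 or 4 or 5")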

    def test_get_config(self):
[Issue] Missing function or method docstring.
        resize = layer.Resize3d(shape=(2, 3, 4))
        config = resize.get_config()
        assert config == dict(
            shape=(2, 3, 4),
            method=tf.image.ResizeMethod.BILINEAR,
            name="resize3d",
            trainable=True,
            dtype="float32",
        )

    def test_shape_err(self):
[Issue] Missing function or method docstring.
        with pytest.raises(AssertionError):
            layer.Resize3d(shape=(2, 3))

    def test_image_shape_err(self):
[Issue] Missing function or method docstring.
        with pytest.raises(ValueError) as err_info:
            resize = layer.Resize3d(shape=(2, 3, 4))
            resize(tf.ones(1, 1))
        assert "Resize3d takes input image of dimension 3 or 4 or 5" in str(
            err_info.value
        )