Passed
Pull Request — main (#656)
by Yunguan
02:34
created

test.unit.test_layer.TestResize3d.test_forward()   A

Complexity

Conditions 1

Size

Total Lines 17
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 15
dl 0
loc 17
rs 9.65
c 0
b 0
f 0
cc 1
nop 4
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/layer
5
"""
6
from typing import Optional
7
8
import numpy as np
9
import pytest
10
import tensorflow as tf
11
12
import deepreg.model.layer as layer
13
14
15
@pytest.mark.parametrize(
    ("input_shape", "output_shape", "expected_shape"),
    [
        ((6, 7, 8), (12, 14, 16), (12, 14, 16)),
        ((6, 7, 8), (11, 13, 15), (11, 13, 15)),
        ((6, 7, 8), None, (12, 14, 16)),
    ],
)
def test_deconv3d(
    input_shape: tuple, output_shape: Optional[tuple], expected_shape: tuple
):
    """
    Test output shapes and configs of layer.Deconv3d.

    :param input_shape: spatial input shape (dim1, dim2, dim3), without the
        batch and channel axes
    :param output_shape: output shape passed to the layer; in the None case the
        test expects the input shape scaled by the stride of 2
    :param expected_shape: expected spatial output shape
    """
    batch_size = 5
    filters = 3

    deconv3d = layer.Deconv3d(filters=filters, strides=2, output_shape=output_shape)

    # Add a batch axis in front and a single channel axis at the end.
    inputs = tf.ones(shape=(batch_size, *input_shape, 1))
    output = deconv3d(inputs)

    assert output.shape == (batch_size, *expected_shape, filters)

    config = deconv3d.get_config()
    assert config == dict(
        filters=filters,
        output_shape=output_shape,
        kernel_size=3,
        strides=2,
        padding="same",
        name="deconv3d",
        trainable=True,
        dtype="float32",
    )
51
52
53
@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
def test_norm_block(layer_name: str, norm_name: str, activation: str):
    """
    Check that layer.NormBlock produces the expected output shape and config.

    :param layer_name: which layer the block wraps ("conv3d" or "deconv3d")
    :param norm_name: which normalization the block uses ("batch" or "layer")
    :param activation: which activation the block uses ("relu" or "elu")
    """
    block_kwargs = dict(
        layer_name=layer_name,
        norm_name=norm_name,
        activation=activation,
        filters=3,
        kernel_size=1,
        padding="same",
    )
    norm_block = layer.NormBlock(**block_kwargs)

    # Input layout is (batch, *spatial, channels) with 6 input channels.
    batch_and_spatial = (2, 3, 4, 5)
    outputs = norm_block(tf.ones(shape=(*batch_and_spatial, 6)))
    assert outputs.shape == (*batch_and_spatial, 3)

    expected_config = {
        **block_kwargs,
        "name": "norm_block",
        "trainable": True,
        "dtype": "float32",
    }
    assert norm_block.get_config() == expected_config
89
90
91
def test_upsample_resnet_block():
    """
    Build layer.UpSampleResnetBlock and check its default attributes.

    The block is built from a non-skip tensor shape and a skip tensor shape
    whose spatial size is doubled and whose channel count is halved.
    """
    # The attributes under test are only reachable as protected members.
    # pylint: disable=protected-access
    batch, ch = 5, 4
    nonskip_shape = (batch, 32, 32, 16, ch)
    skip_shape = (batch, 64, 64, 32, ch // 2)

    block = layer.UpSampleResnetBlock(8)
    block.build([nonskip_shape, skip_shape])

    assert block._filters == 8
    assert block._concat is False

    expected_types = {
        "_conv3d_block": layer.Conv3dBlock,
        "_residual_block": layer.ResidualConv3dBlock,
        "_deconv3d_block": layer.Deconv3dBlock,
    }
    for attr_name, expected_type in expected_types.items():
        assert isinstance(getattr(block, attr_name), expected_type)
111
112
113
class TestWarping:
    """Tests for layer.Warping."""

    @pytest.mark.parametrize(
        ("moving_image_size", "fixed_image_size"),
        [
            ((1, 2, 3), (3, 4, 5)),
            ((1, 2, 3), (1, 2, 3)),
        ],
    )
    def test_forward(self, moving_image_size, fixed_image_size):
        """
        Warp a moving image with a DDF; output must have the fixed image shape.

        :param moving_image_size: spatial shape of the moving image
        :param fixed_image_size: spatial shape of the fixed image and the DDF
        """
        batch_size = 2
        moving_image = tf.ones(shape=(batch_size, *moving_image_size))
        ddf = tf.ones(shape=(batch_size, *fixed_image_size, 3))

        warping = layer.Warping(fixed_image_size=fixed_image_size)
        warped = warping([ddf, moving_image])

        assert warped.shape == (batch_size, *fixed_image_size)

    def test_get_config(self):
        """Check the config of a Warping layer built without name/dtype."""
        expected = {
            "fixed_image_size": (2, 3, 4),
            "name": "warping",
            "trainable": True,
            "dtype": "float32",
        }
        assert layer.Warping(fixed_image_size=(2, 3, 4)).get_config() == expected
137
138
139
@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
@pytest.mark.parametrize("num_layers", [2, 3])
def test_res_block(layer_name: str, norm_name: str, activation: str, num_layers: int):
    """
    Check output shape and config of layer.ResidualBlock.

    :param layer_name: which layer the block uses ("conv3d" or "deconv3d")
    :param norm_name: which normalization the block uses ("batch" or "layer")
    :param activation: which activation the block uses ("relu" or "elu")
    :param num_layers: number of blocks inside the residual block
    """
    channels = 3
    block_kwargs = dict(
        layer_name=layer_name,
        num_layers=num_layers,
        norm_name=norm_name,
        activation=activation,
        filters=channels,
        kernel_size=3,
        padding="same",
    )
    res_block = layer.ResidualBlock(**block_kwargs)

    input_size = (2, 3, 4, 5, channels)  # (batch, *spatial, channels)
    outputs = res_block(tf.ones(shape=input_size))
    assert outputs.shape == (*input_size[:-1], 3)

    assert res_block.get_config() == {
        **block_kwargs,
        "name": "res_block",
        "trainable": True,
        "dtype": "float32",
    }
180
181
182
class TestIntDVF:
    """Tests for layer.IntDVF."""

    def test_forward(self):
        """
        Check that IntDVF keeps the input shape and reports its config.
        """
        image_size = (8, 9, 10)
        shape = (2, *image_size, 3)

        int_dvf = layer.IntDVF(fixed_image_size=image_size)
        result = int_dvf(tf.ones(shape=shape))
        assert result.shape == shape

        assert int_dvf.get_config() == {
            "fixed_image_size": image_size,
            "num_steps": 7,
            "name": "int_dvf",
            "trainable": True,
            "dtype": "float32",
        }

    def test_err(self):
        """A fixed_image_size with only two axes must fail an assertion."""
        with pytest.raises(AssertionError):
            layer.IntDVF(fixed_image_size=(2, 3))
209
210
211
def test_additive_upsampling():
    """
    Test the layer.AdditiveUpSampling class: default attributes, the output
    shape of call(), and the ValueError raised for an unsupported stride.
    """
    channels = 8
    batch_size = 5
    output_size = (32, 32, 16)
    input_size = (24, 24, 16)

    # Test __init__()
    model = layer.AdditiveUpSampling(output_size)
    assert model._stride == 2  # pylint: disable=protected-access
    assert model._output_shape == output_size  # pylint: disable=protected-access

    # Test call(): spatial axes are resized to output_size and the channel
    # count is divided by the default stride of 2.
    inputs = np.ones((batch_size, *input_size, channels))
    output = model(inputs)
    # Compare whole shapes: a zip()-based element check would silently pass
    # when the ranks differ, because zip() stops at the shorter sequence.
    assert output.shape == (batch_size, *output_size, channels // 2)

    # Test the exceptions: with stride=3 the call is expected to raise,
    # presumably because 8 channels are not divisible by 3.
    model = layer.AdditiveUpSampling(output_size, stride=3)
    with pytest.raises(ValueError):
        model(inputs)
239
240
241
def test_local_net_residual3d_block():
    """
    Check the default settings of the convolution held by
    layer.LocalNetResidual3dBlock: kernel size, strides, padding and bias.
    """
    block = layer.LocalNetResidual3dBlock(8)
    conv = block._conv3d  # pylint: disable=protected-access

    assert conv.kernel_size == (3, 3, 3)
    assert conv.strides == (1, 1, 1)
    assert conv.padding == "same"
    assert conv.use_bias is False
254
255
256
def test_local_net_upsample_resnet_block():
    """
    Build layer.LocalNetUpSampleResnetBlock and check its default attributes
    and the types of its sub-layers.
    """
    # The attributes under test are only reachable as protected members.
    # pylint: disable=protected-access
    batch, channels = 5, 4
    nonskip_shape = (batch, 32, 32, 16, channels)
    skip_shape = (batch, 64, 64, 32, channels)

    block = layer.LocalNetUpSampleResnetBlock(8)
    block.build([nonskip_shape, skip_shape])

    assert block._filters == 8
    assert block._use_additive_upsampling is True

    for attr_name, expected_type in (
        ("_deconv3d_block", layer.Deconv3dBlock),
        ("_additive_upsampling", layer.AdditiveUpSampling),
        ("_conv3d_block", layer.Conv3dBlock),
        ("_residual_block", layer.LocalNetResidual3dBlock),
    ):
        assert isinstance(getattr(block, attr_name), expected_type)
280
281
282
class TestResizeCPTransform:
    """Tests for layer.ResizeCPTransform."""

    @pytest.mark.parametrize(
        "parameter,cp_spacing", [((8, 8, 8), 8), ((8, 24, 16), (8, 24, 16))]
    )
    def test_attributes(self, parameter, cp_spacing):
        """
        Check cp_spacing and kernel_sigma right after construction.

        :param parameter: expected cp_spacing as a 3-tuple
        :param cp_spacing: spacing passed to the layer, either an int (used for
            every axis) or a 3-tuple
        """
        model = layer.ResizeCPTransform(cp_spacing)

        if isinstance(cp_spacing, int):
            cp_spacing = [cp_spacing] * 3
        assert list(model.cp_spacing) == list(parameter)
        assert model.kernel_sigma == [0.44 * cp for cp in cp_spacing]

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 68, 68, 68, 3), (12, 8, 12), (8, 16, 8)),
            ((1, 68, 68, 68, 3), (12, 12, 12), 8),
        ],
    )
    def test_build(self, input_size, output_size, cp_spacing):
        """
        Check the control-point grid shape computed by build().

        The expected sizes are the spatial output shapes that test_call expects
        for the same inputs. They are consistent with ceil(68 / spacing) + 3
        per axis.

        :param input_size: full input shape (batch, dim1, dim2, dim3, channels)
        :param output_size: expected spatial control-point grid shape
        :param cp_spacing: control point spacing, int or 3-tuple
        """
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        # Compare as plain ints so the check holds whether the layer stores
        # Python ints or scalar tensors. Comparing whole lists also makes the
        # assertion able to fail.
        actual = [int(size) for size in model._output_shape]  # pylint: disable=protected-access
        assert actual == list(output_size)

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 68, 68, 68, 3), (1, 12, 8, 12, 3), (8, 16, 8)),
            ((1, 68, 68, 68, 3), (1, 12, 12, 12, 3), 8),
        ],
    )
    def test_call(self, input_size, output_size, cp_spacing):
        """
        Check the output shape of calling the layer on a random field.

        :param input_size: full input shape (batch, dim1, dim2, dim3, channels)
        :param output_size: expected full output shape
        :param cp_spacing: control point spacing, int or 3-tuple
        """
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        inputs = tf.random.normal(shape=input_size, dtype=tf.float32)
        output = model(inputs)

        assert output.shape == output_size
322
323
324
class TestBSplines3DTransform:
    """
    Test the layer.BSplines3DTransform class,
    its default attributes and its call() function.
    """

    # (input_size, control point spacing) cases shared by the build,
    # coefficient and interpolation tests.
    _CASES = [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))]

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), 8), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_init(self, input_size, cp):
        """
        Check that the spacing is stored as a 3-tuple.

        :param input_size: full input shape (batch, dim1, dim2, dim3, channels)
        :param cp: spacing passed to the layer, an int or a 3-tuple
        """
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        if isinstance(cp, int):
            cp = (cp, cp, cp)

        assert model.cp_spacing == cp

    def generate_filter_coefficients(self, cp_spacing):
        """
        Build the reference cubic B-spline filter for the given spacing.

        For each axis with spacing s, a 1D kernel of length 4 * s is built.
        Entry x * s + u holds basis function B_x evaluated at
        1 - (u + 0.5) / s. The 3D weights are the outer product of the three
        1D kernels. They are written on the diagonal of the last two axes,
        so input channel i only contributes to output channel i.

        This is a helper, not a test: pytest does not collect it, so it
        carries no parametrize mark.

        :param cp_spacing: control point spacing per axis, a 3-sequence of ints
        :return: float32 array of shape
            (4 * cp_spacing[0], 4 * cp_spacing[1], 4 * cp_spacing[2], 3, 3)
        """
        # Cubic B-spline basis functions B_0 .. B_3.
        basis = (
            lambda u: (1 - u) ** 3 / 6,
            lambda u: (3 * (u ** 3) - 6 * (u ** 2) + 4) / 6,
            lambda u: (-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6,
            lambda u: u ** 3 / 6,
        )

        def kernel_1d(spacing):
            # Index x * spacing + u holds B_x(1 - (u + 0.5) / spacing).
            u_norm = 1 - (np.arange(spacing, dtype=np.float64) + 0.5) / spacing
            return np.concatenate([b(u_norm) for b in basis])

        k0, k1, k2 = (kernel_1d(s) for s in cp_spacing)
        # Separable 3D weights: weights[i, j, k] = k0[i] * k1[j] * k2[k].
        weights = k0[:, None, None] * k1[None, :, None] * k2[None, None, :]

        filters = np.zeros((*weights.shape, 3, 3), dtype=np.float32)
        for it_dim in range(3):
            filters[..., it_dim, it_dim] = weights
        return filters

    @pytest.mark.parametrize("input_size,cp", _CASES)
    def test_build(self, input_size, cp):
        """
        Check the shape of the filter created by build().

        :param input_size: full input shape (batch, dim1, dim2, dim3, channels)
        :param cp: control point spacing as a 3-tuple
        """
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        model.build(input_size)
        assert model.filter.shape == (
            4 * cp[0],
            4 * cp[1],
            4 * cp[2],
            3,
            3,
        )

    @pytest.mark.parametrize("input_size,cp", _CASES)
    def test_coefficients(self, input_size, cp):
        """
        Compare the layer's filter against the reference B-spline filter.

        :param input_size: full input shape (batch, dim1, dim2, dim3, channels)
        :param cp: control point spacing as a 3-tuple
        """
        filters = self.generate_filter_coefficients(cp_spacing=cp)

        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        assert np.allclose(filters, model.filter.numpy(), atol=1e-8)

    @pytest.mark.parametrize("input_size,cp", _CASES)
    def test_interpolation(self, input_size, cp):
        """
        Check that call() maps a control-point field back to the input shape.

        The field has ceil(size / spacing) + 3 control points per spatial axis,
        with the batch and channel axes taken from input_size.

        :param input_size: full input shape (batch, dim1, dim2, dim3, channels)
        :param cp: control point spacing as a 3-tuple
        """
        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        vol_shape = input_size[1:-1]
        num_cp = (
            [input_size[0]]
            + [int(np.ceil(isize / cpsize) + 3) for isize, cpsize in zip(vol_shape, cp)]
            + [input_size[-1]]
        )

        field = tf.random.normal(shape=num_cp, dtype=tf.float32)

        ddf = model.call(field)
        assert ddf.shape == input_size
435
436
437
class TestResize3d:
    """Tests for layer.Resize3d."""

    @pytest.mark.parametrize(
        ("input_shape", "resize_shape", "output_shape"),
        [
            ((1, 2, 3), (3, 4, 5), (3, 4, 5)),
            ((2, 1, 2, 3), (3, 4, 5), (2, 3, 4, 5)),
            ((2, 1, 2, 3, 1), (3, 4, 5), (2, 3, 4, 5, 1)),
            ((2, 1, 2, 3, 6), (3, 4, 5), (2, 3, 4, 5, 6)),
            ((1, 2, 3), (1, 2, 3), (1, 2, 3)),
            ((2, 1, 2, 3), (1, 2, 3), (2, 1, 2, 3)),
            ((2, 1, 2, 3, 1), (1, 2, 3), (2, 1, 2, 3, 1)),
            ((2, 1, 2, 3, 6), (1, 2, 3), (2, 1, 2, 3, 6)),
        ],
    )
    def test_forward(self, input_shape, resize_shape, output_shape):
        """
        Check output shapes for 3D, 4D and 5D inputs.

        :param input_shape: input shape (dim1, dim2, dim3), optionally with a
            leading batch axis and a trailing channel axis
        :param resize_shape: target spatial shape
        :param output_shape: expected output shape
        """
        inputs = tf.ones(shape=input_shape)
        outputs = layer.Resize3d(shape=resize_shape)(inputs)
        assert outputs.shape == output_shape

    def test_get_config(self):
        """Check the config of a Resize3d layer built without method/name/dtype."""
        resize = layer.Resize3d(shape=(2, 3, 4))
        config = resize.get_config()
        assert config == dict(
            shape=(2, 3, 4),
            method=tf.image.ResizeMethod.BILINEAR,
            name="resize3d",
            trainable=True,
            dtype="float32",
        )

    def test_shape_err(self):
        """A target shape with only two axes must fail an assertion."""
        with pytest.raises(AssertionError):
            layer.Resize3d(shape=(2, 3))

    def test_image_shape_err(self):
        """Calling on an input that is not 3D, 4D or 5D raises ValueError."""
        # Construct outside pytest.raises so only the call can satisfy it.
        resize = layer.Resize3d(shape=(2, 3, 4))
        # Pass the shape as a tuple: tf.ones(1, 1) would treat the second
        # argument as the dtype.
        with pytest.raises(ValueError) as err_info:
            resize(tf.ones(shape=(1, 1)))
        assert "Resize3d takes input image of dimension 3 or 4 or 5" in str(
            err_info.value
        )
478