Passed
Pull Request — main (#656)
by Yunguan
03:05
created

test.unit.test_layer.test_res_block()   A

Complexity

Conditions 1

Size

Total Lines 40
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 40
rs 9.16
c 0
b 0
f 0
cc 1
nop 4
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/model/layer
5
"""
6
from typing import Optional
7
8
import numpy as np
9
import pytest
10
import tensorflow as tf
11
12
import deepreg.model.layer as layer
13
14
15
@pytest.mark.parametrize(
    ("input_shape", "output_shape", "expected_shape"),
    [
        ((6, 7, 8), (12, 14, 16), (12, 14, 16)),
        ((6, 7, 8), (11, 13, 15), (11, 13, 15)),
        ((6, 7, 8), None, (12, 14, 16)),
    ],
)
def test_deconv3d(
    input_shape: tuple, output_shape: Optional[tuple], expected_shape: tuple
):
    """
    Test output shapes and configs of layer.Deconv3d.

    When output_shape is None the expected spatial shape in the parameter
    cases is the input shape doubled, consistent with strides=2 below.

    :param input_shape: spatial input shape for layer definition,
        without batch and channel axes
    :param output_shape: output shape for layer definition, or None
    :param expected_shape: expected spatial output shape
    """
    batch_size = 5
    filters = 3
    strides = 2

    deconv3d = layer.Deconv3d(
        filters=filters, strides=strides, output_shape=output_shape
    )

    # single-channel input; the layer is expected to produce `filters` channels
    inputs = tf.ones(shape=(batch_size,) + input_shape + (1,))
    output = deconv3d(inputs)

    assert output.shape == (batch_size,) + expected_shape + (filters,)

    config = deconv3d.get_config()
    assert config == dict(
        filters=filters,
        output_shape=output_shape,
        kernel_size=3,
        strides=strides,
        padding="same",
        name="deconv3d",
        trainable=True,
        dtype="float32",
    )
51
52
53
@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
def test_norm_block(layer_name: str, norm_name: str, activation: str):
    """
    Check the NormBlock output shape and that get_config reports the
    constructor arguments together with the default layer fields.

    :param layer_name: layer_name for layer definition
    :param norm_name: norm_name for layer definition
    :param activation: activation for layer definition
    """
    block_kwargs = dict(
        layer_name=layer_name,
        norm_name=norm_name,
        activation=activation,
        filters=3,
        kernel_size=1,
        padding="same",
    )
    shape_in = (2, 3, 4, 5, 6)  # (batch, *spatial, channels)

    norm_block = layer.NormBlock(**block_kwargs)
    result = norm_block(tf.ones(shape=shape_in))
    assert result.shape == (*shape_in[:-1], block_kwargs["filters"])

    expected_config = {
        **block_kwargs,
        "name": "norm_block",
        "trainable": True,
        "dtype": "float32",
    }
    assert norm_block.get_config() == expected_config
89
90
91
def test_upsample_resnet_block():
    """
    Build layer.UpSampleResnetBlock and check its default attributes.

    The block is built from a coarse input shape and a skip shape whose
    spatial size is doubled and whose channel count is halved.
    """
    batch_size, channels = 5, 4
    coarse_shape = (batch_size, 32, 32, 16, channels)
    skip_shape = (batch_size, 64, 64, 32, channels // 2)

    block = layer.UpSampleResnetBlock(8)
    block.build([coarse_shape, skip_shape])

    assert block._filters == 8
    assert block._concat is False

    # sub-layers created by the block, checked in a fixed order
    expected_types = {
        "_conv3d_block": layer.Conv3dBlock,
        "_residual_block": layer.ResidualConv3dBlock,
        "_deconv3d_block": layer.Deconv3dBlock,
    }
    for attr_name, expected_cls in expected_types.items():
        assert isinstance(getattr(block, attr_name), expected_cls)
111
112
113
class TestWarping:
    """Tests for layer.Warping: forward output shape and get_config."""

    @pytest.mark.parametrize(
        ("moving_image_size", "fixed_image_size"),
        [
            ((1, 2, 3), (3, 4, 5)),
            ((1, 2, 3), (1, 2, 3)),
        ],
    )
    def test_forward(self, moving_image_size, fixed_image_size):
        """
        Warp a moving image with a displacement field and check that the
        output takes the fixed image size.

        :param moving_image_size: non-batch shape of the moving image tensor
        :param fixed_image_size: fixed image size; the ddf has this shape
            plus a trailing axis of length 3
        """
        batch = 2
        moving = tf.ones(shape=(batch, *moving_image_size))
        ddf = tf.ones(shape=(batch, *fixed_image_size, 3))

        warping = layer.Warping(fixed_image_size=fixed_image_size)
        warped = warping([ddf, moving])

        assert warped.shape == (batch, *fixed_image_size)

    def test_get_config(self):
        """Check that get_config reports fixed_image_size and default fields."""
        expected = {
            "fixed_image_size": (2, 3, 4),
            "name": "warping",
            "trainable": True,
            "dtype": "float32",
        }
        assert layer.Warping(fixed_image_size=(2, 3, 4)).get_config() == expected
137
138
139
@pytest.mark.parametrize("layer_name", ["conv3d", "deconv3d"])
@pytest.mark.parametrize("norm_name", ["batch", "layer"])
@pytest.mark.parametrize("activation", ["relu", "elu"])
@pytest.mark.parametrize("num_layers", [2, 3])
def test_res_block(layer_name: str, norm_name: str, activation: str, num_layers: int):
    """
    Test output shapes and configs of layer.ResidualBlock.

    The block is given as many filters as the input has channels, so the
    expected output channel count is `ch` as well.

    :param layer_name: layer_name for layer definition
    :param norm_name: norm_name for layer definition
    :param activation: activation for layer definition
    :param num_layers: number of blocks in res block
    """
    ch = 3
    input_size = (2, 3, 4, 5, ch)  # (batch, *shape, ch)
    res_block = layer.ResidualBlock(
        layer_name=layer_name,
        num_layers=num_layers,
        norm_name=norm_name,
        activation=activation,
        filters=ch,
        kernel_size=3,
        padding="same",
    )
    inputs = tf.ones(shape=input_size)
    outputs = res_block(inputs)
    # use `ch` rather than a literal so the expectation follows `filters`
    assert outputs.shape == input_size[:-1] + (ch,)

    config = res_block.get_config()
    assert config == dict(
        layer_name=layer_name,
        num_layers=num_layers,
        norm_name=norm_name,
        activation=activation,
        filters=ch,
        kernel_size=3,
        padding="same",
        name="res_block",
        trainable=True,
        dtype="float32",
    )
180
181
182
def test_init_dvf():
    """
    Test the layer.IntDVF class, its default attributes and its call() method.

    call() is expected to return an output with the same shape as its input,
    (batch, *fixed_image_size, ndims).
    """
    batch_size = 5
    fixed_image_size = (32, 32, 16)
    ndims = len(fixed_image_size)

    model = layer.IntDVF(fixed_image_size)

    assert isinstance(model._warping, layer.Warping)
    assert model._num_steps == 7

    inputs = np.ones((batch_size, *fixed_image_size, ndims))
    output = model.call(inputs)
    # compare whole shapes: zip() would silently ignore a rank mismatch
    assert tuple(output.shape) == (batch_size, *fixed_image_size, ndims)
202
203
204
def test_additive_upsampling():
    """
    Test the layer.AdditiveUpSampling class, its default attributes, call(),
    and the ValueError raised when constructed with stride=3.
    """
    channels = 8
    batch_size = 5
    output_size = (32, 32, 16)
    input_size = (24, 24, 16)

    # Test __init__()
    model = layer.AdditiveUpSampling(output_size)
    assert model._stride == 2
    assert model._output_shape == output_size

    # Test call(): spatial axes become output_size and channels are halved
    inputs = np.ones((batch_size, *input_size, channels))
    output = model(inputs)
    # compare whole shapes (zip() would ignore a rank mismatch) and use
    # integer division so the expected channel count is an int
    assert tuple(output.shape) == (batch_size, *output_size, channels // 2)

    # Test the exceptions; presumably raised because 8 channels are not
    # divisible by stride 3 — TODO confirm against the layer implementation
    model = layer.AdditiveUpSampling(output_size, stride=3)
    with pytest.raises(ValueError):
        model(inputs)
232
233
234
def test_local_net_residual3d_block():
    """
    Check the default settings of the convolution held in the `_conv3d`
    attribute of layer.LocalNetResidual3dBlock.

    Only construction is exercised here; call() is not invoked.
    """
    conv3d = layer.LocalNetResidual3dBlock(8)._conv3d

    assert conv3d.kernel_size == (3, 3, 3)
    assert conv3d.strides == (1, 1, 1)
    assert conv3d.padding == "same"
    assert conv3d.use_bias is False
247
248
249
def test_local_net_upsample_resnet_block():
    """
    Build layer.LocalNetUpSampleResnetBlock and check its default attributes
    and the types of its sub-layers.

    Only __init__() and build() are exercised here; call() is not invoked.
    """
    batch_size, channels = 5, 4
    nonskip_shape = (batch_size, 32, 32, 16, channels)
    skip_shape = (batch_size, 64, 64, 32, channels)

    block = layer.LocalNetUpSampleResnetBlock(8)
    block.build([nonskip_shape, skip_shape])

    assert block._filters == 8
    assert block._use_additive_upsampling is True

    sublayer_types = (
        ("_deconv3d_block", layer.Deconv3dBlock),
        ("_additive_upsampling", layer.AdditiveUpSampling),
        ("_conv3d_block", layer.Conv3dBlock),
        ("_residual_block", layer.LocalNetResidual3dBlock),
    )
    for attr_name, expected_cls in sublayer_types:
        assert isinstance(getattr(block, attr_name), expected_cls), attr_name
273
274
275
class TestResizeCPTransform:
    """Tests for layer.ResizeCPTransform: attributes, build() and call()."""

    @pytest.mark.parametrize(
        "parameter,cp_spacing", [((8, 8, 8), 8), ((8, 24, 16), (8, 24, 16))]
    )
    def test_attributes(self, parameter, cp_spacing):
        """
        Check that cp_spacing is stored per axis and that kernel_sigma is
        0.44 times the spacing along each axis.

        :param parameter: expected per-axis control point spacing
        :param cp_spacing: spacing passed to the layer, an int or a 3-tuple
        """
        model = layer.ResizeCPTransform(cp_spacing)

        if isinstance(cp_spacing, int):
            cp_spacing = [cp_spacing] * 3
        assert list(model.cp_spacing) == list(parameter)
        assert model.kernel_sigma == [0.44 * cp for cp in cp_spacing]

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 68, 68, 68, 3), (12, 8, 12), (8, 16, 8)),
            ((1, 68, 68, 68, 3), (12, 12, 12), 8),
        ],
    )
    def test_build(self, input_size, output_size, cp_spacing):
        """
        Check the spatial output shape computed by build().

        The expected shapes are the spatial part of the outputs expected in
        test_call for the same input size and spacing.

        :param input_size: input shape, (batch, *spatial, channels)
        :param output_size: expected spatial output shape
        :param cp_spacing: spacing passed to the layer, an int or a 3-tuple
        """
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        # compare the full shape; int() normalises entries in case they are
        # scalar tensors rather than Python ints
        assert [int(v) for v in model._output_shape] == list(output_size)

    @pytest.mark.parametrize(
        "input_size,output_size,cp_spacing",
        [
            ((1, 68, 68, 68, 3), (1, 12, 8, 12, 3), (8, 16, 8)),
            ((1, 68, 68, 68, 3), (1, 12, 12, 12, 3), 8),
        ],
    )
    def test_call(self, input_size, output_size, cp_spacing):
        """
        Check the output shape of call() on a random input.

        :param input_size: input shape, (batch, *spatial, channels)
        :param output_size: expected output shape
        :param cp_spacing: spacing passed to the layer, an int or a 3-tuple
        """
        model = layer.ResizeCPTransform(cp_spacing)
        model.build(input_size)

        inputs = tf.random.normal(shape=input_size, dtype=tf.float32)
        output = model(inputs)

        assert output.shape == output_size
315
316
317
class TestBSplines3DTransform:
    """
    Test the layer.BSplines3DTransform class,
    its default attributes and its call() function.
    """

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), 8), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_init(self, input_size, cp):
        """
        Check that an int control point spacing is expanded to a 3-tuple.

        :param input_size: input shape; input_size[1:-1] is passed to the layer
        :param cp: control point spacing, an int or a 3-tuple
        """
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        if isinstance(cp, int):
            cp = (cp, cp, cp)

        assert model.cp_spacing == cp

    def generate_filter_coefficients(self, cp_spacing):
        """
        Build the reference filter compared against BSplines3DTransform.filter.

        The 3D weights are separable. Along each axis, the weight at index
        x * spacing + u is basis function x (of the four cubic polynomials
        below) evaluated at 1 - (u + 0.5) / spacing. The 3D weight is the
        product of the three axis weights. The same 3D weights are written to
        every diagonal entry [..., d, d] of the last two axes; off-diagonal
        entries stay zero.

        :param cp_spacing: control point spacing, a sequence of three ints
        :return: float32 array of shape
            (4 * cp_spacing[0], 4 * cp_spacing[1], 4 * cp_spacing[2], 3, 3)
        """
        basis = (
            lambda u: np.float64((1 - u) ** 3 / 6),
            lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            lambda u: np.float64(u ** 3 / 6),
        )

        def axis_weights(spacing):
            """Return the float64 weights of length 4 * spacing for one axis."""
            weights = np.zeros(4 * spacing, dtype=np.float64)
            for u in range(spacing):
                u_norm = 1 - (u + 0.5) / spacing
                for x, basis_x in enumerate(basis):
                    weights[x * spacing + u] = basis_x(u_norm)
            return weights

        w0, w1, w2 = (axis_weights(s) for s in cp_spacing[:3])
        # outer product, multiplied in the same order as w0 * w1 * w2 per voxel
        kernel = w0[:, None, None] * w1[None, :, None] * w2[None, None, :]

        filters = np.zeros((*kernel.shape, 3, 3), dtype=np.float32)
        for it_dim in range(3):
            filters[..., it_dim, it_dim] = kernel
        return filters

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_build(self, input_size, cp):
        """
        Check the shape of the filter created by build().

        :param input_size: input shape; input_size[1:-1] is passed to the layer
        :param cp: control point spacing, a 3-tuple
        """
        model = layer.BSplines3DTransform(cp, input_size[1:-1])

        model.build(input_size)
        assert model.filter.shape == (
            4 * cp[0],
            4 * cp[1],
            4 * cp[2],
            3,
            3,
        )

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_coefficients(self, input_size, cp):
        """
        Compare the filter created by build() with generate_filter_coefficients.

        :param input_size: input shape; input_size[1:-1] is passed to the layer
        :param cp: control point spacing, a 3-tuple
        """
        filters = self.generate_filter_coefficients(cp_spacing=cp)

        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        assert np.allclose(filters, model.filter.numpy(), atol=1e-8)

    @pytest.mark.parametrize(
        "input_size,cp",
        [((1, 8, 8, 8, 3), (8, 8, 8)), ((1, 8, 8, 8, 3), (8, 16, 12))],
    )
    def test_interpolation(self, input_size, cp):
        """
        Check that call() maps a control point grid back to input_size.

        The grid has ceil(size / spacing) + 3 points per spatial axis, with
        the batch and last axes taken from input_size.

        :param input_size: input shape; input_size[1:-1] is passed to the layer
        :param cp: control point spacing, a 3-tuple
        """
        model = layer.BSplines3DTransform(cp, input_size[1:-1])
        model.build(input_size)

        vol_shape = input_size[1:-1]
        num_cp = (
            [input_size[0]]
            + [int(np.ceil(isize / cpsize) + 3) for isize, cpsize in zip(vol_shape, cp)]
            + [input_size[-1]]
        )

        field = tf.random.normal(shape=num_cp, dtype=tf.float32)

        ddf = model.call(field)
        assert ddf.shape == input_size
428
429
430
class TestResize3d:
    """Tests for layer.Resize3d: forward shapes, get_config and validation."""

    @pytest.mark.parametrize(
        ("input_shape", "resize_shape", "output_shape"),
        [
            ((1, 2, 3), (3, 4, 5), (3, 4, 5)),
            ((2, 1, 2, 3), (3, 4, 5), (2, 3, 4, 5)),
            ((2, 1, 2, 3, 1), (3, 4, 5), (2, 3, 4, 5, 1)),
            ((2, 1, 2, 3, 6), (3, 4, 5), (2, 3, 4, 5, 6)),
            ((1, 2, 3), (1, 2, 3), (1, 2, 3)),
            ((2, 1, 2, 3), (1, 2, 3), (2, 1, 2, 3)),
            ((2, 1, 2, 3, 1), (1, 2, 3), (2, 1, 2, 3, 1)),
            ((2, 1, 2, 3, 6), (1, 2, 3), (2, 1, 2, 3, 6)),
        ],
    )
    def test_forward(self, input_shape, resize_shape, output_shape):
        """
        Check that only the three spatial axes are resized.

        The cases cover 3D inputs, 4D inputs with a leading batch axis, and
        5D inputs with batch and trailing channel axes.

        :param input_shape: shape of the input tensor
        :param resize_shape: target spatial shape passed to the layer
        :param output_shape: expected output shape
        """
        inputs = tf.ones(shape=input_shape)
        outputs = layer.Resize3d(shape=resize_shape)(inputs)
        assert outputs.shape == output_shape

    def test_get_config(self):
        """
        Check that get_config reports the shape, the default bilinear method
        and the default layer fields.
        """
        resize = layer.Resize3d(shape=(2, 3, 4))
        config = resize.get_config()
        assert config == dict(
            shape=(2, 3, 4),
            method=tf.image.ResizeMethod.BILINEAR,
            name="resize3d",
            trainable=True,
            dtype="float32",
        )

    def test_shape_err(self):
        """A target shape that is not three-dimensional is rejected."""
        with pytest.raises(AssertionError):
            layer.Resize3d(shape=(2, 3))

    def test_image_shape_err(self):
        """Calling the layer on a 2D tensor raises a descriptive ValueError."""
        # construct outside pytest.raises so only the call can satisfy it
        resize = layer.Resize3d(shape=(2, 3, 4))
        # pass the shape as one tuple: the second positional argument of
        # tf.ones is the dtype, so tf.ones(1, 1) would not build a 2D tensor
        with pytest.raises(ValueError) as err_info:
            resize(tf.ones((1, 1)))
        assert "Resize3d takes input image of dimension 3 or 4 or 5" in str(
            err_info.value
        )
471