Passed
Pull Request — main (#673)
by Yunguan
03:13
created

TestPyramidCombination.test_2d()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 14

Duplication

Lines 22
Ratio 100 %

Importance

Changes 0
Metric Value
eloc 14
dl 22
loc 22
rs 9.7
c 0
b 0
f 0
cc 1
nop 1
1
"""
2
Tests for deepreg/model/layer_util.py in
3
pytest style
4
"""
5
from test.unit.util import is_equal_tf
6
from typing import Tuple, Union
7
8
import numpy as np
9
import pytest
10
import tensorflow as tf
11
12
import deepreg.model.layer_util as layer_util
13
14
15
def test_get_reference_grid():
    """
    Check get_reference_grid on a grid of size (1, 2, 3).

    Every voxel of the expected grid stores its own integer coordinates,
    e.g. position (0, 1, 2) holds [0, 1, 2]; the result is compared with
    is_equal_tf.
    """
    first_row = [[0, 0, 0], [0, 0, 1], [0, 0, 2]]
    second_row = [[0, 1, 0], [0, 1, 1], [0, 1, 2]]
    expected = tf.constant(np.array([[first_row, second_row]], dtype=np.float32))

    got = layer_util.get_reference_grid(grid_size=[1, 2, 3])

    assert is_equal_tf(expected, got)
28
29
30
def test_get_n_bits_combinations():
    """
    Check get_n_bits_combinations for n = 1, 2 and 3.

    Each expected result lists all n-bit patterns as lists of 0/1 ints in
    ascending binary order, most significant bit first.
    """
    expected_by_num_bits = {
        1: [[0], [1]],
        2: [[0, 0], [0, 1], [1, 0], [1, 1]],
        3: [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
    }
    # dicts keep insertion order, so the cases run as n = 1, 2, 3
    for num_bits, expected in expected_by_num_bits.items():
        assert layer_util.get_n_bits_combinations(num_bits) == expected
51
52
53
class TestPyramidCombination:
    """
    Tests for layer_util.pyramid_combination.

    In every case weight_ceil is 1 - weight_floor. A corner coordinate of 0
    uses the floor weight of that axis and a coordinate of 1 uses the ceil
    weight.
    """

    @staticmethod
    def _assert_combination(values, weights, expected):
        """Run pyramid_combination with weight_ceil = 1 - weights and compare."""
        got = layer_util.pyramid_combination(
            values=values, weight_floor=weights, weight_ceil=1 - weights
        )
        assert is_equal_tf(got, expected)

    def test_1d(self):
        """One axis: result is values[0] * floor + values[1] * ceil."""
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))

        # corner 0 takes the floor weight 0.2, corner 1 the ceil weight 0.8
        # expected = 1 * 0.2 + 2 * 0.8
        expected = tf.constant(np.array([1.8], dtype=np.float32))
        self._assert_combination(values, weights, expected)

    def test_2d(self):
        """Two axes: each corner weight is the product of its per-axis weights."""
        weights = tf.constant(np.array([[0.2], [0.3]], dtype=np.float32))
        values = tf.constant(
            np.array(
                [
                    [1],  # value at corner (0, 0), weight = 0.2 * 0.3
                    [2],  # value at corner (0, 1), weight = 0.2 * 0.7
                    [3],  # value at corner (1, 0), weight = 0.8 * 0.3
                    [4],  # value at corner (1, 1), weight = 0.8 * 0.7
                ],
                dtype=np.float32,
            )
        )
        # expected = 1 * 0.2 * 0.3
        #          + 2 * 0.2 * 0.7
        #          + 3 * 0.8 * 0.3
        #          + 4 * 0.8 * 0.7
        expected = tf.constant(np.array([3.3], dtype=np.float32))
        self._assert_combination(values, weights, expected)

    def test_error_dim(self):
        """Weights with one more dimension than the values raise ValueError."""
        # elements of weights have shape (1, 1); elements of values have shape (1,)
        weights = tf.constant(np.array([[[0.2]], [[0.2]]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, elements of values, weight_floor, "
            "and weight_ceil should have same dimension" in str(err_info.value)
        )

    def test_error_len(self):
        """len(values) != 2 ** len(weight_floor) raises ValueError."""
        # one weight requires 2 ** 1 = 2 values, but only one is given
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim" in str(err_info.value)
        )
111
112
113
class TestLinearResample:
    """
    Tests for layer_util.resample with linear interpolation on a 2D volume.

    The volume stores 3 * x + y at each integer point of [0, 2] x [0, 2]
    (vol[0, x, y] == 3 * x + y). The expected value at any location can
    therefore be written in closed form.
    """

    x_min, x_max = 0, 2
    y_min, y_max = 0, 2
    # value at grid point (x, y) is 3 * x + y
    # shape = (1, 3, 3)
    vol = tf.constant(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]), dtype=tf.float32)
    # sample locations grouped by position relative to the grid
    # shape = (1, 4, 3, 2)
    loc = tf.constant(
        np.array(
            [
                [
                    [[0, 0], [0, 1], [1, 2]],  # integer points on the boundary
                    [[0.4, 0], [0.5, 2], [2, 1.7]],  # non-integer points on edges
                    [[-0.4, 0.7], [0, 3], [2, 3]],  # outside the grid
                    [[0.4, 0.7], [1, 1], [0.6, 0.3]],  # strictly inside
                ]
            ]
        ),
        dtype=tf.float32,
    )

    @staticmethod
    def _with_channels(vol, expected, channel):
        """
        Append a trailing channel axis of size `channel` to both tensors.

        For channel <= 0 both tensors are returned unchanged (no channel axis).
        """
        if channel <= 0:
            return vol, expected
        return (
            tf.repeat(vol[..., None], channel, axis=-1),
            tf.repeat(expected[..., None], channel, axis=-1),
        )

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_extrapolation(self, channel):
        """
        With zero_boundary=False the expected value is 3 * x + y evaluated at
        the location clipped into [x_min, x_max] x [y_min, y_max].
        """
        x_clipped = tf.clip_by_value(self.loc[..., 0], self.x_min, self.x_max)
        y_clipped = tf.clip_by_value(self.loc[..., 1], self.y_min, self.y_max)
        vol, expected = self._with_channels(
            self.vol, 3 * x_clipped + y_clipped, channel
        )

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=False)
        assert is_equal_tf(expected, got)

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_zero_bound(self, channel):
        """
        With zero_boundary=True the expected value is 3 * x + y inside the
        mask below and 0 elsewhere.
        """
        x = self.loc[..., 0]
        y = self.loc[..., 1]
        # NOTE(review): the mask excludes the lower bound (x == x_min) but keeps
        # the upper bound (x == x_max); this encodes what the test expects from
        # resample — confirm the asymmetry is intended.
        inside = (
            tf.cast(x > self.x_min, tf.float32)
            * tf.cast(x <= self.x_max, tf.float32)
            * tf.cast(y > self.y_min, tf.float32)
            * tf.cast(y <= self.y_max, tf.float32)
        )
        vol, expected = self._with_channels(self.vol, (3 * x + y) * inside, channel)

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=True)
        assert is_equal_tf(expected, got)

    def test_shape_error(self):
        """A vol of shape (1, 1) with a loc of shape (2, 2) raises ValueError."""
        vol = tf.constant(np.array([[0]], dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc)
        assert "vol shape inconsistent with loc" in str(err_info.value)

    def test_interpolation_error(self):
        """Any interpolation other than linear (here "nearest") raises ValueError."""
        # NOTE(review): these vol/loc shapes are the same inconsistent pair used
        # in test_shape_error, so this test relies on resample validating the
        # interpolation before the shapes — consider consistent shapes.
        vol = tf.constant(np.array([[0]], dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc, interpolation="nearest")
        assert "resample supports only linear interpolation" in str(err_info.value)
190
191
192
class TestWarpGrid:
    """
    Tests for layer_util.warp_grid against hand-computed results.

    The grid stores its own coordinates at every voxel of a (1, 2, 3) grid,
    e.g. position (0, 1, 2) holds [0, 1, 2].
    """

    grid = tf.constant(
        np.array(
            [[[[0, 0, 0], [0, 0, 1], [0, 0, 2]], [[0, 1, 0], [0, 1, 1], [0, 1, 2]]]],
            dtype=np.float32,
        )
    )  # shape = (1, 2, 3, 3)

    def test_identical(self):
        """
        A theta whose first three rows form the identity and whose last row is
        zero leaves every grid point unchanged.
        """
        identity = np.eye(4, 3)[np.newaxis, ...]  # shape = (1, 4, 3)
        theta = tf.constant(identity, dtype=tf.float32)
        # the expected output carries one extra leading axis: (1, 1, 2, 3, 3)
        expected = tf.expand_dims(self.grid, axis=0)

        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, expected)

    def test_non_identical(self):
        """
        A general theta: each expected point equals [i, j, k, 1] @ theta[0].

        For example grid point (0, 0, 1) maps to theta row 2 + row 3
        (0-indexed) = [0.84, 0.92, 1.01].
        """
        theta_rows = [
            [0.86, 0.75, 0.48],
            [0.07, 0.98, 0.01],
            [0.72, 0.52, 0.97],
            [0.12, 0.4, 0.04],
        ]
        theta = tf.constant(np.array([theta_rows], dtype=np.float32))  # (1, 4, 3)

        expected_points = [
            [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
            [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
        ]
        # shape = (1, 1, 2, 3, 3)
        expected = tf.constant(np.array([[expected_points]], dtype=np.float32))

        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, expected)
240
241
242
class TestGaussianFilter3D:
    """Tests for layer_util.gaussian_filter_3d."""

    @pytest.mark.parametrize(
        "kernel_sigma, kernel_size",
        [
            ((1, 1, 1), (3, 3, 3, 3, 3)),
            ((2, 2, 2), (7, 7, 7, 3, 3)),
            ((5, 5, 5), (15, 15, 15, 3, 3)),
            (1, (3, 3, 3, 3, 3)),
            (2, (7, 7, 7, 3, 3)),
            (5, (15, 15, 15, 3, 3)),
        ],
    )
    def test_kernel_size(self, kernel_sigma, kernel_size):
        """
        The filter shape matches the expected size for each sigma.

        Sigma is given either per axis (tuple) or as a scalar; both forms are
        expected to give the same shape, and the last two axes are always 3.
        """
        # named `kernel` rather than `filter` to avoid shadowing the builtin
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        assert kernel.shape == kernel_size

    @pytest.mark.parametrize(
        "kernel_sigma",
        [(1, 1, 1), (2, 2, 2), (5, 5, 5)],
    )
    def test_sum(self, kernel_sigma):
        """
        All filter entries sum to 3 (within atol=1e-3).

        NOTE(review): presumably one normalised kernel per channel for 3
        channels — confirm against gaussian_filter_3d.
        """
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        assert np.allclose(np.sum(kernel), 3, atol=1e-3)
265
266
267
class TestDeconvOutputPadding:
    """
    Tests for the output padding of transposed convolutions.

    Covers the 1D helper layer_util._deconv_output_padding and the
    n-dimensional layer_util.deconv_output_padding.
    """

    @pytest.mark.parametrize(
        ("input_shape", "output_shape", "kernel_size", "stride", "padding", "expected"),
        [
            (5, 5, 3, 1, "same", 0),
            (5, 7, 3, 1, "valid", 0),
            (5, 3, 3, 1, "full", 0),
            (5, 6, 3, 1, "same", 1),
            (5, 8, 3, 1, "valid", 1),
            (5, 4, 3, 1, "full", 1),
            (5, 9, 3, 2, "same", 0),
            (5, 11, 3, 2, "valid", 0),
            (5, 7, 3, 2, "full", 0),
        ],
    )
    def test_1d(
        self,
        input_shape: int,
        output_shape: int,
        kernel_size: int,
        stride: int,
        padding: str,
        expected: int,
    ):
        """
        Test _deconv_output_padding by verifying its output along one axis.

        :param input_shape: size of the Conv3DTranspose input along the axis
        :param output_shape: size of the Conv3DTranspose output along the axis
        :param kernel_size: kernel size of Conv3DTranspose layer
        :param stride: stride of Conv3DTranspose layer
        :param padding: padding of Conv3DTranspose layer
        :param expected: expected output padding for Conv3DTranspose layer
        """
        # the private 1D helper is tested directly on purpose
        got = layer_util._deconv_output_padding(  # pylint: disable=protected-access
            input_shape, output_shape, kernel_size, stride, padding
        )
        assert got == expected

    def test_1d_err(self):
        """Test _deconv_output_padding err raising."""
        with pytest.raises(ValueError) as err_info:
            # an unsupported padding mode must be rejected
            layer_util._deconv_output_padding(  # pylint: disable=protected-access
                5, 5, 3, 1, "x"
            )
        assert "Unknown padding" in str(err_info.value)

    @pytest.mark.parametrize(
        ("input_shape", "output_shape", "kernel_size", "stride", "padding", "expected"),
        [
            (5, 9, 3, 2, "same", 0),
            ((5, 5), (9, 10), 3, 2, "same", (0, 1)),
            ((5, 5, 6), (9, 10, 12), 3, 2, "same", (0, 1, 1)),
            ((5, 5), (9, 10), (3, 3), 2, "same", (0, 1)),
            ((5, 5), (9, 10), 3, (2, 2), "same", (0, 1)),
            ((5, 5), (9, 10), (3, 4), 2, "same", (0, 2)),
        ],
    )
    def test_n_dim(
        self,
        input_shape: Union[Tuple[int, ...], int],
        output_shape: Union[Tuple[int, ...], int],
        kernel_size: Union[Tuple[int, ...], int],
        stride: Union[Tuple[int, ...], int],
        padding: str,
        expected: Union[Tuple[int, ...], int],
    ):
        """
        Test deconv_output_padding by verifying output for 1 to 3 dimensions.

        Shapes, kernel sizes and strides may each be a scalar or a tuple of
        any length (hence Tuple[int, ...]); the cases mix both forms.

        :param input_shape: shape of Conv3DTranspose input tensor
        :param output_shape: shape of Conv3DTranspose output tensor
        :param kernel_size: kernel size of Conv3DTranspose layer
        :param stride: stride of Conv3DTranspose layer
        :param padding: padding of Conv3DTranspose layer
        :param expected: expected output padding for Conv3DTranspose layer
        """
        got = layer_util.deconv_output_padding(
            input_shape, output_shape, kernel_size, stride, padding
        )
        assert got == expected
346