Passed
Pull Request — main (#656)
by Yunguan
03:08
created

test.unit.test_layer_util.test_warp_image_ddf()   A

Complexity

Conditions 4

Size

Total Lines 47
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 47
rs 9.16
c 0
b 0
f 0
cc 4
nop 0
1
"""
2
Tests for deepreg/model/layer_util.py in
3
pytest style
4
"""
5
from test.unit.util import is_equal_tf
6
7
import numpy as np
8
import pytest
9
import tensorflow as tf
10
11
import deepreg.model.layer_util as layer_util
12
13
14
def test_get_reference_grid():
    """
    Check get_reference_grid against a hand-written grid.

    For grid_size [1, 2, 3] the result is compared, within
    is_equal_tf's tolerance, to an array of shape (1, 2, 3, 3)
    whose last axis holds the (i, j, k) index of each grid point.
    """
    expected_grid = np.array(
        [
            [
                [[0, 0, 0], [0, 0, 1], [0, 0, 2]],
                [[0, 1, 0], [0, 1, 1], [0, 1, 2]],
            ]
        ],
        dtype=np.float32,
    )
    want = tf.constant(expected_grid)
    get = layer_util.get_reference_grid(grid_size=[1, 2, 3])
    assert is_equal_tf(want, get)
27
28
29
def test_get_n_bits_combinations():
    """
    Check get_n_bits_combinations for n = 1, 2 and 3.

    Each expected value lists every n-bit pattern, in the order
    the function is expected to return them.
    """
    # (n, expected combinations), checked in order of increasing n
    cases = (
        (1, [[0], [1]]),
        (2, [[0, 0], [0, 1], [1, 0], [1, 1]]),
        (
            3,
            [
                [0, 0, 0],
                [0, 0, 1],
                [0, 1, 0],
                [0, 1, 1],
                [1, 0, 0],
                [1, 0, 1],
                [1, 1, 0],
                [1, 1, 1],
            ],
        ),
    )
    for n_bits, combinations in cases:
        assert layer_util.get_n_bits_combinations(n_bits) == combinations
50
51
52
class TestPyramidCombination:
    """
    Test pyramid_combination on small 1D and 2D cases
    and on inputs that it must reject with a ValueError.
    """

    def test_1d(self):
        """Check the weighted sum of the two corner values in 1D."""
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))

        # expected = 1 * 0.2 + 2 * 0.8
        expected = tf.constant(np.array([1.8], dtype=np.float32))
        got = layer_util.pyramid_combination(
            values=values, weight_floor=weights, weight_ceil=1 - weights
        )
        assert is_equal_tf(got, expected)

    def test_2d(self):
        """Check the weighted sum of the four corner values in 2D."""
        weights = tf.constant(np.array([[0.2], [0.3]], dtype=np.float32))
        values = tf.constant(
            np.array(
                [
                    [1],  # value at corner (0, 0), weight = 0.2 * 0.3
                    [2],  # value at corner (0, 1), weight = 0.2 * 0.7
                    [3],  # value at corner (1, 0), weight = 0.8 * 0.3
                    [4],  # value at corner (1, 1), weight = 0.8 * 0.7
                ],
                dtype=np.float32,
            )
        )
        # expected = 1 * 0.2 * 0.3
        #          + 2 * 0.2 * 0.7
        #          + 3 * 0.8 * 0.3
        #          + 4 * 0.8 * 0.7
        expected = tf.constant(np.array([3.3], dtype=np.float32))
        got = layer_util.pyramid_combination(
            values=values, weight_floor=weights, weight_ceil=1 - weights
        )
        assert is_equal_tf(got, expected)

    def test_error_dim(self):
        """Check that weights with one more axis than values are rejected."""
        # weights has shape (2, 1, 1) while values has shape (2, 1)
        weights = tf.constant(np.array([[[0.2]], [[0.2]]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, elements of values, weight_floor, "
            "and weight_ceil should have same dimension" in str(err_info.value)
        )

    def test_error_len(self):
        """Check that len(values) != 2 ** len(weight_floor) is rejected."""
        # one weight implies two corner values, but only one is given
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim" in str(err_info.value)
        )
110
111
112
class TestLinearResample:
    """
    Test resample with linear interpolation on a small 2D volume,
    with and without zero boundary, and check its input validation.
    """

    # bounds of the grid along each axis
    x_min, x_max = 0, 2
    y_min, y_max = 0, 2

    # vol holds the value 3x+y at every integer point (x, y)
    # of the grid [0,2]x[0,2], shape = (1,3,3)
    vol = tf.constant(
        np.array(
            [
                [
                    [0, 1, 2],
                    [3, 4, 5],
                    [6, 7, 8],
                ]
            ]
        ),
        dtype=tf.float32,
    )

    # loc holds the sample points, one row per kind of position
    # relative to the grid, shape = (1,4,3,2)
    loc = tf.constant(
        np.array(
            [
                [
                    [[0, 0], [0, 1], [1, 2]],  # boundary corners
                    [[0.4, 0], [0.5, 2], [2, 1.7]],  # boundary edge
                    [[-0.4, 0.7], [0, 3], [2, 3]],  # outside boundary
                    [[0.4, 0.7], [1, 1], [0.6, 0.3]],  # internal
                ]
            ]
        ),
        dtype=tf.float32,
    )

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_extrapolation(self, channel):
        """
        Check resample with zero_boundary=False.

        The expected value at each point is 3x+y evaluated after
        clipping the coordinates to the grid bounds.

        :param channel: number of trailing channels added to vol,
            0 means vol is used without a channel axis.
        """
        clipped_x = tf.clip_by_value(self.loc[..., 0], self.x_min, self.x_max)
        clipped_y = tf.clip_by_value(self.loc[..., 1], self.y_min, self.y_max)
        expected = 3 * clipped_x + clipped_y

        volume = self.vol
        if channel > 0:
            # copy the volume and the expectation along a new last axis
            volume = tf.repeat(volume[..., None], channel, axis=-1)
            expected = tf.repeat(expected[..., None], channel, axis=-1)

        got = layer_util.resample(vol=volume, loc=self.loc, zero_boundary=False)
        assert is_equal_tf(expected, got)

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_zero_bound(self, channel):
        """
        Check resample with zero_boundary=True.

        The expected value is 3x+y, set to zero wherever a coordinate
        is not strictly above its lower bound or is above its upper bound.

        :param channel: number of trailing channels added to vol,
            0 means vol is used without a channel axis.
        """
        x = self.loc[..., 0]
        y = self.loc[..., 1]
        expected = 3 * x + y
        bounds = (
            (x, self.x_min, self.x_max),
            (y, self.y_min, self.y_max),
        )
        for coord, lower, upper in bounds:
            # mask is 1 only where lower < coord <= upper
            expected = (
                expected
                * tf.cast(coord > lower, tf.float32)
                * tf.cast(coord <= upper, tf.float32)
            )

        volume = self.vol
        if channel > 0:
            # copy the volume and the expectation along a new last axis
            volume = tf.repeat(volume[..., None], channel, axis=-1)
            expected = tf.repeat(expected[..., None], channel, axis=-1)

        got = layer_util.resample(vol=volume, loc=self.loc, zero_boundary=True)
        assert is_equal_tf(expected, got)

    def test_shape_error(self):
        """Check that a vol/loc shape mismatch raises a ValueError."""
        vol = tf.constant(np.zeros((1, 1), dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.zeros((2, 2), dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc)
        assert "vol shape inconsistent with loc" in str(err_info.value)

    def test_interpolation_error(self):
        """Check that a non-linear interpolation raises a ValueError."""
        vol = tf.constant(np.zeros((1, 1), dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.zeros((2, 2), dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc, interpolation="nearest")
        assert "resample supports only linear interpolation" in str(err_info.value)
189
190
191
class TestWarpGrid:
    """
    Test warp_grid by comparing its output with
    precomputed results on a small reference grid.
    """

    # reference grid shared by the tests, shape = (1, 2, 3, 3)
    grid = tf.constant(
        np.array(
            [
                [
                    [[0, 0, 0], [0, 0, 1], [0, 0, 2]],
                    [[0, 1, 0], [0, 1, 1], [0, 1, 2]],
                ]
            ],
            dtype=np.float32,
        )
    )

    def test_identical(self):
        """
        Check that theta = np.eye(4, 3) leaves the grid unchanged,
        apart from the added leading axis.
        """
        eye = np.eye(4, 3)[np.newaxis]  # shape = (1, 4, 3)
        theta = tf.constant(eye, dtype=tf.float32)
        expected = self.grid[None, ...]  # shape = (1, 1, 2, 3, 3)
        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, expected)

    def test_non_identical(self):
        """
        Check warp_grid with a general theta.

        The expected values are precomputed as [i, j, k, 1] @ theta
        for every grid point (i, j, k).
        """
        theta_values = np.array(
            [
                [
                    [0.86, 0.75, 0.48],
                    [0.07, 0.98, 0.01],
                    [0.72, 0.52, 0.97],
                    [0.12, 0.4, 0.04],
                ]
            ],
            dtype=np.float32,
        )  # shape = (1, 4, 3)
        expected_values = np.array(
            [
                [
                    [
                        [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
                        [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
                    ]
                ]
            ],
            dtype=np.float32,
        )  # shape = (1, 1, 2, 3, 3)
        theta = tf.constant(theta_values)
        expected = tf.constant(expected_values)
        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, expected)
239
240
241
def test_resize3d():
    """
    Test resize3d by confirming the output shapes
    and the errors raised on invalid inputs.
    """
    # (input_shape, size, output_shape) - Pass
    shape_cases = (
        # different size, without channel nor batch
        ((1, 3, 5), (2, 4, 6), (2, 4, 6)),
        # different size, without channel
        ((1, 1, 3, 5), (2, 4, 6), (1, 2, 4, 6)),
        # different size, with one channel
        ((1, 1, 3, 5, 1), (2, 4, 6), (1, 2, 4, 6, 1)),
        # different size, with multiple channels
        ((1, 1, 3, 5, 3), (2, 4, 6), (1, 2, 4, 6, 3)),
        # same size, without channel nor batch
        ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        # same size, without channel
        ((1, 1, 3, 5), (1, 3, 5), (1, 1, 3, 5)),
        # same size, with one channel
        ((1, 1, 3, 5, 1), (1, 3, 5), (1, 1, 3, 5, 1)),
        # same size, with multiple channels
        ((1, 1, 3, 5, 3), (1, 3, 5), (1, 1, 3, 5, 3)),
    )
    for input_shape, size, output_shape in shape_cases:
        got = layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert got.shape == output_shape

    # (input_shape, size, expected error message) - Fail
    error_cases = (
        # image has too few dimensions
        (
            (1, 1),
            (1, 1, 1),
            "resize3d takes input image of dimension 3 or 4 or 5",
        ),
        # size does not have three entries
        (
            (1, 1, 1),
            (1, 1),
            "resize3d takes size of type tuple/list and of length 3",
        ),
    )
    for input_shape, size, message in error_cases:
        with pytest.raises(ValueError) as err_info:
            layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert message in str(err_info.value)
317
318
319
class TestGaussianFilter3D:
    """
    Test gaussian_filter_3d by checking the shape
    and the sum of the kernel it returns.
    """

    @pytest.mark.parametrize(
        "kernel_sigma, kernel_size",
        [
            ((1, 1, 1), (3, 3, 3, 3, 3)),
            ((2, 2, 2), (7, 7, 7, 3, 3)),
            ((5, 5, 5), (15, 15, 15, 3, 3)),
            (1, (3, 3, 3, 3, 3)),
            (2, (7, 7, 7, 3, 3)),
            (5, (15, 15, 15, 3, 3)),
        ],
    )
    def test_kernel_size(self, kernel_sigma, kernel_size):
        """
        Check the kernel shape for tuple and scalar sigma.

        :param kernel_sigma: sigma passed to gaussian_filter_3d,
            either a tuple of three values or a single value.
        :param kernel_size: expected shape of the returned kernel.
        """
        # named kernel rather than filter to avoid shadowing the built-in
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        assert kernel.shape == kernel_size

    @pytest.mark.parametrize(
        "kernel_sigma",
        [(1, 1, 1), (2, 2, 2), (5, 5, 5)],
    )
    def test_sum(self, kernel_sigma):
        """
        Check that the kernel entries sum to 3 within 1e-3.

        :param kernel_sigma: tuple of three sigma values
            passed to gaussian_filter_3d.
        """
        # NOTE(review): a sum of 3 presumably means one normalised
        # kernel per channel - confirm against gaussian_filter_3d
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        assert np.allclose(np.sum(kernel), 3, atol=1e-3)
342