Passed
Pull Request — main (#656)
by Yunguan
02:44
created

test.unit.test_layer_util.test_warp_image_ddf()   A

Complexity

Conditions 4

Size

Total Lines 47
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 47
rs 9.16
c 0
b 0
f 0
cc 4
nop 0
1
"""
2
Tests for deepreg/model/layer_util.py in
3
pytest style
4
"""
5
from test.unit.util import is_equal_tf
6
7
import numpy as np
8
import pytest
9
import tensorflow as tf
10
11
import deepreg.model.layer_util as layer_util
12
13
14
def test_get_reference_grid():
    """
    Check that get_reference_grid builds the expected coordinate grid.

    For grid_size [1, 2, 3] the entry at index (i, j, k) holds the
    coordinates [i, j, k], giving a tensor of shape (1, 2, 3, 3).
    """
    coords = [
        [[[0, 0, 0], [0, 0, 1], [0, 0, 2]], [[0, 1, 0], [0, 1, 1], [0, 1, 2]]]
    ]
    expected_grid = tf.constant(np.array(coords, dtype=np.float32))
    actual_grid = layer_util.get_reference_grid(grid_size=[1, 2, 3])
    assert is_equal_tf(expected_grid, actual_grid)
27
28
29
def test_get_n_bits_combinations():
    """
    Check get_n_bits_combinations for the 1D, 2D and 3D cases.

    Each expected list enumerates every n-bit pattern in increasing binary
    order, i.e. entry i is the binary representation of i.
    """
    cases = {
        1: [[0], [1]],
        2: [[0, 0], [0, 1], [1, 0], [1, 1]],
        3: [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
    }
    for n_bits, combinations in cases.items():
        assert layer_util.get_n_bits_combinations(n_bits) == combinations
50
51
52
class TestPyramidCombination:
    """
    Tests for layer_util.pyramid_combination.

    Every call passes weight_ceil = 1 - weight_floor. The expected values are
    the corner values weighted by products of the floor/ceil weights (see the
    per-test arithmetic comments).
    """

    def test_1d(self):
        """One dimension: two values combined with a single floor/ceil pair."""
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))

        # value 1 gets weight_floor = 0.2, value 2 gets weight_ceil = 0.8
        # expected = 1 * 0.2 + 2 * 0.8 = 1.8
        expected = tf.constant(np.array([1.8], dtype=np.float32))
        got = layer_util.pyramid_combination(
            values=values, weight_floor=weights, weight_ceil=1 - weights
        )
        assert is_equal_tf(got, expected)

    def test_2d(self):
        """Two dimensions: four corner values combined with per-axis weights."""
        weights = tf.constant(np.array([[0.2], [0.3]], dtype=np.float32))
        values = tf.constant(
            np.array(
                [
                    [1],  # value at corner (0, 0), weight = 0.2 * 0.3
                    [2],  # value at corner (0, 1), weight = 0.2 * 0.7
                    [3],  # value at corner (1, 0), weight = 0.8 * 0.3
                    [4],  # value at corner (1, 1), weight = 0.8 * 0.7
                ],
                dtype=np.float32,
            )
        )
        # expected = 1 * 0.2 * 0.3
        #          + 2 * 0.2 * 0.7
        #          + 3 * 0.8 * 0.3
        #          + 4 * 0.8 * 0.7
        #          = 3.3
        expected = tf.constant(np.array([3.3], dtype=np.float32))
        got = layer_util.pyramid_combination(
            values=values, weight_floor=weights, weight_ceil=1 - weights
        )
        assert is_equal_tf(got, expected)

    def test_error_dim(self):
        """Mismatched ranks between values and weights raise ValueError."""
        # weights have shape (2, 1, 1) while values have shape (2, 1)
        weights = tf.constant(np.array([[[0.2]], [[0.2]]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, elements of values, weight_floor, "
            "and weight_ceil should have same dimension" in str(err_info.value)
        )

    def test_error_len(self):
        """len(values) != 2 ** len(weight_floor) raises ValueError."""
        # one weight implies num_dim = 1, so two values would be required
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim" in str(err_info.value)
        )
110
111
112
class TestLinearResample:
    """
    Tests for layer_util.resample on a small 2D volume.

    The volume samples the linear function 3x + y on the grid [0, 2] x [0, 2].
    Linear interpolation of a linear function is exact, so in-range points
    should resample to 3x + y. The tests differ only in how points on or
    beyond the boundary are expected to be handled.
    """

    x_min, x_max = 0, 2
    y_min, y_max = 0, 2
    # vol holds values on the grid [0, 2] x [0, 2];
    # the value at grid point (x, y) is 3x + y
    # shape = (1, 3, 3)
    vol = tf.constant(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]), dtype=tf.float32)
    # loc holds the sampling points, one row per category:
    # grid points on the boundary, off-grid points on the boundary edges,
    # points outside the grid, and points strictly inside the grid
    # shape = (1, 4, 3, 2)
    loc = tf.constant(
        np.array(
            [
                [
                    [[0, 0], [0, 1], [1, 2]],  # grid points on the boundary
                    [[0.4, 0], [0.5, 2], [2, 1.7]],  # boundary edge
                    [[-0.4, 0.7], [0, 3], [2, 3]],  # outside boundary
                    [[0.4, 0.7], [1, 1], [0.6, 0.3]],  # internal
                ]
            ]
        ),
        dtype=tf.float32,
    )

    @staticmethod
    def _add_channels(vol, expected, channel):
        """
        Optionally append a trailing channel axis to vol and expected.

        :param vol: volume tensor to resample.
        :param expected: expected resampled values.
        :param channel: 0 leaves both inputs unchanged (no channel axis);
            a positive value appends a trailing axis of that size by
            repeating the existing values.
        :return: tuple (vol, expected), possibly expanded.
        """
        if channel > 0:
            vol = tf.repeat(vol[..., None], channel, axis=-1)
            expected = tf.repeat(expected[..., None], channel, axis=-1)
        return vol, expected

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_extrapolation(self, channel):
        """
        With zero_boundary=False, out-of-range points take boundary values.

        The expected values clip each coordinate into the grid range before
        evaluating 3x + y.
        """
        x = self.loc[..., 0]
        y = self.loc[..., 1]
        x = tf.clip_by_value(x, self.x_min, self.x_max)
        y = tf.clip_by_value(y, self.y_min, self.y_max)
        expected = 3 * x + y

        vol, expected = self._add_channels(self.vol, expected, channel)

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=False)
        assert is_equal_tf(expected, got)

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_zero_bound(self, channel):
        """
        With zero_boundary=True, points off the valid range resample to zero.

        NOTE(review): as encoded below, a coordinate is kept only when it lies
        in (min, max], i.e. the lower bound is excluded and the upper bound is
        included. Points with x == 0 or y == 0 are therefore expected to be 0.
        """
        x = self.loc[..., 0]
        y = self.loc[..., 1]
        expected = 3 * x + y
        expected = (
            expected
            * tf.cast(x > self.x_min, tf.float32)
            * tf.cast(x <= self.x_max, tf.float32)
        )
        expected = (
            expected
            * tf.cast(y > self.y_min, tf.float32)
            * tf.cast(y <= self.y_max, tf.float32)
        )

        vol, expected = self._add_channels(self.vol, expected, channel)

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=True)
        assert is_equal_tf(expected, got)

    def test_shape_error(self):
        """A loc whose shape is inconsistent with vol raises ValueError."""
        vol = tf.constant(np.array([[0]], dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc)
        assert "vol shape inconsistent with loc" in str(err_info.value)

    def test_interpolation_error(self):
        """Any interpolation other than linear raises ValueError."""
        interpolation = "nearest"
        vol = tf.constant(np.array([[0]], dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc, interpolation=interpolation)
        assert "resample supports only linear interpolation" in str(err_info.value)
189
190
191
def test_random_transform_generator():
    """
    Check gen_rand_affine_transform's output shape and its seeded values.

    Transforms are expected to have shape (batch_size, 4, 3). With a fixed
    seed, the values must match a precomputed reference.
    """
    # Shape check for a batch of one.
    single_batch = 1
    shape_probe = layer_util.gen_rand_affine_transform(single_batch, 0)
    assert shape_probe.shape == (single_batch, 4, 3)

    # Seeded numerical check; reference shape = (1, 4, 3).
    reference = np.array(
        [
            [
                [9.4661278e-01, -3.8267835e-03, 3.6934228e-03],
                [5.5613145e-03, 9.8034811e-01, -1.8044969e-02],
                [1.9651605e-04, 1.4576728e-02, 9.6243286e-01],
                [-2.5107686e-03, 1.9579126e-02, -1.2195010e-02],
            ]
        ],
        dtype=np.float32,
    )
    seeded = layer_util.gen_rand_affine_transform(batch_size=1, scale=0.1, seed=0)
    assert is_equal_tf(seeded, tf.constant(reference))
222
223
224
class TestWarpGrid:
    """
    Tests for layer_util.warp_grid on a small precomputed grid.

    For the cases below, each grid point p is expected to map to the row
    vector [p, 1] @ theta, where theta has shape (1, 4, 3).
    """

    # Reference grid of shape (1, 2, 3, 3): entry (i, j, k) holds [i, j, k].
    grid = tf.constant(
        np.array(
            [[[[0, 0, 0], [0, 0, 1], [0, 0, 2]], [[0, 1, 0], [0, 1, 1], [0, 1, 2]]]],
            dtype=np.float32,
        )
    )

    def test_identical(self):
        """An identity theta leaves the grid unchanged apart from a new leading axis."""
        identity = np.eye(4, 3).reshape((1, 4, 3))
        theta = tf.constant(identity, dtype=tf.float32)
        warped = layer_util.warp_grid(grid=self.grid, theta=theta)
        # expected shape = (1, 1, 2, 3, 3)
        assert is_equal_tf(warped, self.grid[None, ...])

    def test_non_identical(self):
        """A general theta maps each point p to [p, 1] @ theta."""
        theta_rows = [
            [0.86, 0.75, 0.48],
            [0.07, 0.98, 0.01],
            [0.72, 0.52, 0.97],
            [0.12, 0.4, 0.04],  # translation row, applied to every point
        ]
        theta = tf.constant(np.array([theta_rows], dtype=np.float32))  # (1, 4, 3)

        warped_points = [
            [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
            [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
        ]
        expected = tf.constant(
            np.array([[warped_points]], dtype=np.float32)
        )  # (1, 1, 2, 3, 3)

        warped = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(warped, expected)
272
273
274
def test_resize3d():
    """
    Check resize3d output shapes and its argument validation.

    Images of rank 3 (no batch, no channel), rank 4 (no channel) and rank 5
    (with one or several channels) are resized to a target spatial size,
    both when the size changes and when it stays the same. An image of
    unsupported rank, or a size not of length 3, must raise ValueError.
    """
    # (input_shape, size, expected output shape)
    shape_cases = [
        # different size
        ((1, 3, 5), (2, 4, 6), (2, 4, 6)),
        ((1, 1, 3, 5), (2, 4, 6), (1, 2, 4, 6)),
        ((1, 1, 3, 5, 1), (2, 4, 6), (1, 2, 4, 6, 1)),
        ((1, 1, 3, 5, 3), (2, 4, 6), (1, 2, 4, 6, 3)),
        # same size
        ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        ((1, 1, 3, 5), (1, 3, 5), (1, 1, 3, 5)),
        ((1, 1, 3, 5, 1), (1, 3, 5), (1, 1, 3, 5, 1)),
        ((1, 1, 3, 5, 3), (1, 3, 5), (1, 1, 3, 5, 3)),
    ]
    for input_shape, size, output_shape in shape_cases:
        resized = layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert resized.shape == output_shape

    # A rank-2 image is rejected.
    with pytest.raises(ValueError) as err_info:
        layer_util.resize3d(image=tf.ones((1, 1)), size=(1, 1, 1))
    assert "resize3d takes input image of dimension 3 or 4 or 5" in str(err_info.value)

    # A size whose length is not 3 is rejected.
    with pytest.raises(ValueError) as err_info:
        layer_util.resize3d(image=tf.ones((1, 1, 1)), size=(1, 1))
    assert "resize3d takes size of type tuple/list and of length 3" in str(
        err_info.value
    )
350
351
352
class TestGaussianFilter3D:
    """Tests for layer_util.gaussian_filter_3d."""

    @pytest.mark.parametrize(
        "kernel_sigma, kernel_size",
        [
            ((1, 1, 1), (3, 3, 3, 3, 3)),
            ((2, 2, 2), (7, 7, 7, 3, 3)),
            ((5, 5, 5), (15, 15, 15, 3, 3)),
            (1, (3, 3, 3, 3, 3)),
            (2, (7, 7, 7, 3, 3)),
            (5, (15, 15, 15, 3, 3)),
        ],
    )
    def test_kernel_size(self, kernel_sigma, kernel_size):
        """
        Check the kernel shape for per-axis tuple and scalar sigmas.

        The parametrization expects a scalar sigma to produce the same shape
        as the equivalent isotropic tuple.
        """
        # named `kernel` rather than `filter` to avoid shadowing the builtin
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        assert kernel.shape == kernel_size

    @pytest.mark.parametrize(
        "kernel_sigma",
        [(1, 1, 1), (2, 2, 2), (5, 5, 5)],
    )
    def test_sum(self, kernel_sigma):
        """
        Check that all kernel entries sum to 3 (within atol=1e-3).

        NOTE(review): the total of 3 presumably reflects one normalised
        kernel per channel of the trailing (3, 3) axes; confirm against
        gaussian_filter_3d.
        """
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        assert np.allclose(np.sum(kernel), 3, atol=1e-3)
375