Passed
Pull Request — main (#656)
by Yunguan
03:01
created

test.unit.test_layer_util.test_resize3d()   B

Complexity

Conditions 3

Size

Total Lines 75
Code Lines 52

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 52
dl 0
loc 75
rs 8.5709
c 0
b 0
f 0
cc 3
nop 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a cohesive (often commented) part of the body into a new, well-named method.

1
"""
2
Tests for deepreg/model/layer_util.py in
3
pytest style
4
"""
5
from test.unit.util import is_equal_tf
6
7
import numpy as np
8
import pytest
9
import tensorflow as tf
10
11
import deepreg.model.layer_util as layer_util
12
13
14
def test_get_reference_grid():
    """
    Check get_reference_grid on a grid of size [1, 2, 3].

    Every voxel (i, j, k) of the returned grid is expected to hold its own
    integer coordinates [i, j, k]; the comparison uses is_equal_tf's
    tolerance.
    """
    # coordinates in row-major order, reshaped to (1, 2, 3, 3) below
    coords = [
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 2],
        [0, 1, 0],
        [0, 1, 1],
        [0, 1, 2],
    ]
    expected = tf.constant(
        np.array(coords, dtype=np.float32).reshape((1, 2, 3, 3))
    )
    got = layer_util.get_reference_grid(grid_size=[1, 2, 3])
    assert is_equal_tf(expected, got)
27
28
29
def test_get_n_bits_combinations():
    """
    Check get_n_bits_combinations for n = 1, 2 and 3.

    The expected result lists every n-bit pattern as a list of ints, in
    increasing binary order with the most significant bit first.
    """
    expected_by_n = {
        1: [[0], [1]],
        2: [[0, 0], [0, 1], [1, 0], [1, 1]],
        3: [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
    }
    for n, expected in expected_by_n.items():
        assert layer_util.get_n_bits_combinations(n) == expected
50
51
52
class TestPyramidCombination:
    """
    Tests for layer_util.pyramid_combination.

    Covers hand-computed 1D and 2D cases, plus the ValueError raised when
    the inputs are inconsistent.
    """

    @staticmethod
    def _assert_combination(values, weights, expected):
        """
        Run pyramid_combination and compare the result with expected.

        The call uses weight_floor=weights and weight_ceil=1 - weights.

        :param values: corner values, one entry per corner
        :param weights: floor weights; ceil weights are 1 - weights
        :param expected: expected combined value
        """
        got = layer_util.pyramid_combination(
            values=values, weight_floor=weights, weight_ceil=1 - weights
        )
        assert is_equal_tf(got, expected)

    def test_1d(self):
        """
        1D case: two corner values with floor weight 0.2 and ceil weight 0.8.
        """
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))

        # expected = 1 * 0.2 + 2 * 0.8 = 1.8
        expected = tf.constant(np.array([1.8], dtype=np.float32))
        self._assert_combination(values, weights, expected)

    def test_2d(self):
        """
        2D case: four corner values with floor weights (0.2, 0.3).

        The corresponding ceil weights are (0.8, 0.7).
        """
        weights = tf.constant(np.array([[0.2], [0.3]], dtype=np.float32))
        values = tf.constant(
            np.array(
                [
                    [1],  # value at corner (0, 0), weight = 0.2 * 0.3
                    [2],  # value at corner (0, 1), weight = 0.2 * 0.7
                    [3],  # value at corner (1, 0), weight = 0.8 * 0.3
                    [4],  # value at corner (1, 1), weight = 0.8 * 0.7
                ],
                dtype=np.float32,
            )
        )
        # expected = 1 * 0.2 * 0.3
        #          + 2 * 0.2 * 0.7
        #          + 3 * 0.8 * 0.3
        #          + 4 * 0.8 * 0.7
        #          = 0.06 + 0.28 + 0.72 + 2.24 = 3.3
        expected = tf.constant(np.array([3.3], dtype=np.float32))
        self._assert_combination(values, weights, expected)

    def test_error_dim(self):
        """
        Mismatched element ranks must raise ValueError.

        Each weight element has shape (1, 1) while each value element has
        shape (1,).
        """
        weights = tf.constant(np.array([[[0.2]], [[0.2]]], dtype=np.float32))
        values = tf.constant(np.array([[1], [2]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, elements of values, weight_floor, "
            "and weight_ceil should have same dimension" in str(err_info.value)
        )

    def test_error_len(self):
        """
        A wrong number of values must raise ValueError.

        One weight implies num_dim = 1, so 2 ** 1 = 2 values are required,
        but only one is given.
        """
        weights = tf.constant(np.array([[0.2]], dtype=np.float32))
        values = tf.constant(np.array([[1]], dtype=np.float32))
        with pytest.raises(ValueError) as err_info:
            layer_util.pyramid_combination(
                values=values, weight_floor=weights, weight_ceil=1 - weights
            )
        assert (
            "In pyramid_combination, num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim" in str(err_info.value)
        )
110
111
112
class TestLinearResample:
    """
    Tests for layer_util.resample on a 2D volume whose values are 3x + y.

    Because the volume is linear in the coordinates, the expected output at
    each sample location is written in closed form. Error paths for
    inconsistent shapes and unsupported interpolation are also covered.
    """

    x_min, x_max = 0, 2
    y_min, y_max = 0, 2
    # vol holds values on the grid [0, 2] x [0, 2];
    # the value at each point (x, y) is 3x + y
    # shape = (1, 3, 3)
    vol = tf.constant(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]), dtype=tf.float32)
    # loc holds sample points, one row per category: boundary corners,
    # boundary edges, points outside the boundary, and interior points
    # shape = (1, 4, 3, 2)
    loc = tf.constant(
        np.array(
            [
                [
                    [[0, 0], [0, 1], [1, 2]],  # boundary corners
                    [[0.4, 0], [0.5, 2], [2, 1.7]],  # boundary edge
                    [[-0.4, 0.7], [0, 3], [2, 3]],  # outside boundary
                    [[0.4, 0.7], [1, 1], [0.6, 0.3]],  # internal
                ]
            ]
        ),
        dtype=tf.float32,
    )

    @staticmethod
    def _add_channels(vol, expected, channel):
        """
        Optionally append a trailing channel axis to vol and expected.

        :param vol: volume tensor without a channel axis
        :param expected: expected resampled values without a channel axis
        :param channel: 0 leaves both tensors unchanged; a positive value
            appends a trailing axis of that size by repeating each tensor
        :return: tuple (vol, expected), possibly with a trailing channel axis
        """
        if channel > 0:
            vol = tf.repeat(vol[..., None], channel, axis=-1)
            expected = tf.repeat(expected[..., None], channel, axis=-1)
        return vol, expected

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_extrapolation(self, channel):
        """
        Check resample with zero_boundary=False.

        The expected value is 3x + y evaluated on coordinates clipped to
        [x_min, x_max] x [y_min, y_max].

        :param channel: number of channels to append (0 means none)
        """
        x = tf.clip_by_value(self.loc[..., 0], self.x_min, self.x_max)
        y = tf.clip_by_value(self.loc[..., 1], self.y_min, self.y_max)
        vol, expected = self._add_channels(self.vol, 3 * x + y, channel)

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=False)
        assert is_equal_tf(expected, got)

    @pytest.mark.parametrize("channel", [0, 1, 2])
    def test_repeat_zero_bound(self, channel):
        """
        Check resample with zero_boundary=True.

        The expected value is 3x + y for points with x_min < x <= x_max and
        y_min < y <= y_max, and 0 everywhere else.

        :param channel: number of channels to append (0 means none)
        """
        x = self.loc[..., 0]
        y = self.loc[..., 1]
        # 1.0 where the point lies inside the half-open ranges above, else 0.0
        inside = (
            tf.cast(x > self.x_min, tf.float32)
            * tf.cast(x <= self.x_max, tf.float32)
            * tf.cast(y > self.y_min, tf.float32)
            * tf.cast(y <= self.y_max, tf.float32)
        )
        vol, expected = self._add_channels(self.vol, (3 * x + y) * inside, channel)

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=True)
        assert is_equal_tf(expected, got)

    def test_shape_error(self):
        """A loc whose shape does not match vol must raise ValueError."""
        vol = tf.constant(np.array([[0]], dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc)
        assert "vol shape inconsistent with loc" in str(err_info.value)

    def test_interpolation_error(self):
        """Requesting "nearest" interpolation must raise ValueError."""
        interpolation = "nearest"
        vol = tf.constant(np.array([[0]], dtype=np.float32))  # shape = [1,1]
        loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))  # shape = [2,2]
        with pytest.raises(ValueError) as err_info:
            layer_util.resample(vol=vol, loc=loc, interpolation=interpolation)
        assert "resample supports only linear interpolation" in str(err_info.value)
189
190
191
class TestWarpGrid:
    """
    Test layer_util.warp_grid against precomputed results.

    Two cases are checked: an identity transform and a fixed, non-identical
    4x3 transform.
    """

    # 1 x 2 x 3 grid of 3D coordinates, shape = (1, 2, 3, 3)
    grid = tf.constant(
        np.array(
            [[[[0, 0, 0], [0, 0, 1], [0, 0, 2]], [[0, 1, 0], [0, 1, 1], [0, 1, 2]]]],
            dtype=np.float32,
        )
    )

    def test_identical(self):
        """
        An identity theta (np.eye(4, 3)) must return the grid unchanged.

        The result gains an extra leading axis of size 1.
        """
        identity = np.eye(4, 3).reshape((1, 4, 3))  # shape = (1, 4, 3)
        theta = tf.constant(identity, dtype=tf.float32)
        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, self.grid[None, ...])  # shape = (1, 1, 2, 3, 3)

    def test_non_identical(self):
        """
        A fixed 4x3 theta must warp every grid point as precomputed.

        Each expected point equals [x, y, z, 1] @ theta[0]. For example,
        point (0, 0, 1) maps to row 3 + row 4 of theta, which is
        [0.84, 0.92, 1.01].
        """
        rows = [
            [0.86, 0.75, 0.48],
            [0.07, 0.98, 0.01],
            [0.72, 0.52, 0.97],
            [0.12, 0.4, 0.04],
        ]
        theta = tf.constant(np.array(rows, dtype=np.float32)[None, ...])  # (1, 4, 3)

        warped = [
            [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
            [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
        ]
        expected = tf.constant(
            np.array(warped, dtype=np.float32)[None, None, ...]
        )  # shape = (1, 1, 2, 3, 3)

        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, expected)
239
240
241
class TestGaussianFilter3D:
    """
    Tests for layer_util.gaussian_filter_3d.

    Checks the kernel shape for scalar and per-axis sigmas, and the total
    sum of the kernel entries.
    """

    @pytest.mark.parametrize(
        "kernel_sigma, kernel_size",
        [
            ((1, 1, 1), (3, 3, 3, 3, 3)),
            ((2, 2, 2), (7, 7, 7, 3, 3)),
            ((5, 5, 5), (15, 15, 15, 3, 3)),
            (1, (3, 3, 3, 3, 3)),
            (2, (7, 7, 7, 3, 3)),
            (5, (15, 15, 15, 3, 3)),
        ],
    )
    def test_kernel_size(self, kernel_sigma, kernel_size):
        """
        Check that the kernel has the expected shape for the given sigma.

        A scalar sigma and the equivalent identical 3-tuple yield the same
        shape.

        :param kernel_sigma: scalar or 3-tuple of Gaussian sigmas
        :param kernel_size: expected kernel shape
        """
        # named `kernel` rather than `filter` to avoid shadowing the builtin
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        assert kernel.shape == kernel_size

    @pytest.mark.parametrize(
        "kernel_sigma",
        [(1, 1, 1), (2, 2, 2), (5, 5, 5)],
    )
    def test_sum(self, kernel_sigma):
        """
        Check that the kernel entries sum to 3 (within 1e-3).

        :param kernel_sigma: 3-tuple of Gaussian sigmas
        """
        kernel = layer_util.gaussian_filter_3d(kernel_sigma)
        # NOTE(review): the trailing (3, 3) axes seen in test_kernel_size
        # suggest one normalised Gaussian per channel, giving a total of 3 —
        # confirm against gaussian_filter_3d.
        assert np.allclose(np.sum(kernel), 3, atol=1e-3)
264