Passed
Pull Request — main (#656)
by Yunguan
02:43
created

TestRandomTransformation.test_get_config()   A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 16
rs 9.9
c 0
b 0
f 0
cc 1
nop 2
1
"""
2
Tests for deepreg/dataset/preprocess.py in
3
pytest style
4
5
Some internals of the _gen_transform, _transform and
6
transform function, such as:
7
    - layer_util.random_transform_generator
8
    - layer_util.warp_grid
9
    - layer_util.resample
10
Are assumed working, and are tested separately in
11
test_layer_util.py; as such we just check output size here.
12
"""
13
from test.unit.util import is_equal_np, is_equal_tf
14
15
import numpy as np
16
import pytest
17
import tensorflow as tf
18
19
import deepreg.dataset
20
import deepreg.dataset.preprocess as preprocess
21
22
23
@pytest.mark.parametrize(
    ("moving_input_size", "fixed_input_size", "moving_image_size", "fixed_image_size"),
    [
        ((1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)),
        ((3, 4, 5), (4, 5, 6), (1, 2, 3), (2, 3, 4)),
        ((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
    ],
)
@pytest.mark.parametrize("labeled", [True, False])
def test_resize_inputs(
    moving_input_size: tuple,
    fixed_input_size: tuple,
    moving_image_size: tuple,
    fixed_image_size: tuple,
    labeled: bool,
):
    """
    Check return shapes of preprocess.resize_inputs.

    Every key containing "moving" must come back with moving_image_size,
    every other image/label key with fixed_image_size, and "indices" must
    keep its input shape.

    :param moving_input_size: input moving image/label shape
    :param fixed_input_size: input fixed image/label shape
    :param moving_image_size: output moving image/label shape
    :param fixed_image_size: output fixed image/label shape
    :param labeled: if data is labeled
    """
    num_indices = 2

    moving_image = tf.random.uniform(moving_input_size)
    fixed_image = tf.random.uniform(fixed_input_size)
    indices = tf.ones((num_indices,))
    inputs = dict(moving_image=moving_image, fixed_image=fixed_image, indices=indices)
    if labeled:
        inputs["moving_label"] = tf.random.uniform(moving_input_size)
        inputs["fixed_label"] = tf.random.uniform(fixed_input_size)

    outputs = preprocess.resize_inputs(inputs, moving_image_size, fixed_image_size)
    for key, value in inputs.items():
        if key == "indices":
            # indices are not resized; only an unchanged shape is expected
            assert outputs[key].shape == value.shape
            continue
        expected_shape = moving_image_size if "moving" in key else fixed_image_size
        assert outputs[key].shape == expected_shape
68
69
70
def test_random_transform_3d_get_config():
    """
    Check that get_config reproduces the constructor arguments.

    Besides the arguments passed to the constructor, the returned config is
    expected to carry trainable=False and dtype="float32".
    """
    layer_kwargs = {
        "moving_image_size": (1, 2, 3),
        "fixed_image_size": (2, 3, 4),
        "batch_size": 3,
        "name": "TestRandomTransformation3D",
    }
    layer = preprocess.RandomTransformation3D(**layer_kwargs)

    assert layer.get_config() == {
        "trainable": False,
        "dtype": "float32",
        **layer_kwargs,
    }
83
84
85
class TestRandomTransformation:
    """Test all functions of RandomTransformation class."""

    moving_image_size = (1, 2, 3)
    fixed_image_size = (2, 3, 4)
    batch_size = 2
    scale = 0.2
    num_indices = 3
    name = "TestTransformation"
    # constructor arguments shared by every layer under test
    common_config = dict(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        batch_size=batch_size,
        name=name,
    )
    # layer-specific constructor arguments, keyed like layer_cls_dict;
    # the affine scale reuses the `scale` attribute instead of repeating 0.2
    extra_config_dict = dict(
        affine=dict(scale=scale), ddf=dict(field_strength=0.2, low_res_size=(1, 2, 3))
    )
    layer_cls_dict = dict(
        affine=preprocess.RandomAffineTransform3D,
        ddf=preprocess.RandomDDFTransform3D,
    )

    def build_layer(self, name: str) -> preprocess.RandomTransformation3D:
        """
        Build a layer given the layer name.

        :param name: name of the layer, a key of layer_cls_dict ("affine" or "ddf")
        :return: layer constructed from common_config merged with its extra config
        """
        config = {**self.common_config, **self.extra_config_dict[name]}
        return self.layer_cls_dict[name](**config)

    @pytest.mark.parametrize("name", ["affine", "ddf"])
    def test_get_config(self, name: str):
        """
        Check config values.

        get_config is expected to return every constructor argument together
        with trainable=False and dtype="float32".

        :param name: name of the layer
        """
        layer = self.build_layer(name)
        got = layer.get_config()
        expected = {
            "trainable": False,
            "dtype": "float32",
            **self.common_config,
            **self.extra_config_dict[name],
        }
        assert got == expected

    @pytest.mark.parametrize(
        ("name", "moving_param_shape", "fixed_param_shape"),
        [
            # affine params are one (4, 3) matrix per sample
            ("affine", (4, 3), (4, 3)),
            # ddf params hold a 3-vector per voxel of the respective image
            ("ddf", (*moving_image_size, 3), (*fixed_image_size, 3)),
        ],
    )
    def test_gen_transform_params(
        self, name: str, moving_param_shape: tuple, fixed_param_shape: tuple
    ):
        """
        Check return shapes and moving/fixed params should be different.

        :param name: name of the layer
        :param moving_param_shape: params shape for moving image/label, without batch
        :param fixed_param_shape: params shape for fixed image/label, without batch
        """
        layer = self.build_layer(name)
        moving, fixed = layer.gen_transform_params()
        assert moving.shape == (self.batch_size, *moving_param_shape)
        assert fixed.shape == (self.batch_size, *fixed_param_shape)
        # moving and fixed transforms are sampled independently
        assert not is_equal_np(moving, fixed)

    @pytest.mark.parametrize("name", ["affine", "ddf"])
    def test_transform(self, name: str):
        """
        Check that transforming a moving image preserves its shape.

        :param name: name of the layer
        """
        layer = self.build_layer(name)
        moving_image = tf.random.uniform(
            shape=(self.batch_size, *self.moving_image_size)
        )
        moving_params, _ = layer.gen_transform_params()
        transformed = layer.transform(
            image=moving_image,
            grid_ref=layer.moving_grid_ref,
            params=moving_params,
        )
        assert transformed.shape == moving_image.shape

    @pytest.mark.parametrize("name", ["affine", "ddf"])
    @pytest.mark.parametrize("labeled", [True, False])
    def test_call(self, name: str, labeled: bool):
        """
        Check that calling the layer returns every input key with its shape kept.

        :param name: name of the layer
        :param labeled: if data is labeled
        """
        layer = self.build_layer(name)

        moving_shape = (self.batch_size, *self.moving_image_size)
        fixed_shape = (self.batch_size, *self.fixed_image_size)
        moving_image = tf.random.uniform(moving_shape)
        fixed_image = tf.random.uniform(fixed_shape)
        indices = tf.ones((self.batch_size, self.num_indices))
        inputs = dict(
            moving_image=moving_image, fixed_image=fixed_image, indices=indices
        )
        if labeled:
            moving_label = tf.random.uniform(moving_shape)
            fixed_label = tf.random.uniform(fixed_shape)
            inputs["moving_label"] = moving_label
            inputs["fixed_label"] = fixed_label

        outputs = layer.call(inputs)
        for k in inputs:
            assert outputs[k].shape == inputs[k].shape
205
206
207
def test_random_transform_generator():
    """
    Test gen_rand_affine_transform.

    Confirm that it returns parameters of the expected shape and that, for a
    fixed seed, the generated values match precomputed ones.
    """
    # shape check: one (4, 3) parameter matrix per batch element
    batch_size = 1
    # NOTE(review): the positional 0 is presumably `scale` — confirm against
    # the gen_rand_affine_transform signature
    transforms = deepreg.dataset.preprocess.gen_rand_affine_transform(batch_size, 0)
    assert transforms.shape == (batch_size, 4, 3)

    # value check: outputs are reproducible for a given seed
    scale = 0.1
    seed = 0
    expected = tf.constant(
        np.array(
            [
                [
                    [9.4661278e-01, -3.8267835e-03, 3.6934228e-03],
                    [5.5613145e-03, 9.8034811e-01, -1.8044969e-02],
                    [1.9651605e-04, 1.4576728e-02, 9.6243286e-01],
                    [-2.5107686e-03, 1.9579126e-02, -1.2195010e-02],
                ]
            ],
            dtype=np.float32,
        )
    )  # shape = (1, 4, 3)
    got = deepreg.dataset.preprocess.gen_rand_affine_transform(
        batch_size=batch_size, scale=scale, seed=seed
    )
    assert is_equal_tf(got, expected)
238