Passed
Pull Request — main (#656)
by Yunguan
02:36
created

deepreg.dataset.preprocess.gen_rand_ddf()   A

Complexity

Conditions 1

Size

Total Lines 26
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 26
rs 9.8
c 0
b 0
f 0
cc 1
nop 5
1
"""
2
Module containing data augmentation techniques.
3
  - 3D Affine/DDF Transforms for moving and fixed images.
4
"""
5
6
from abc import abstractmethod
from typing import Dict, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from deepreg.model.layer_util import get_reference_grid, resample, resize3d, warp_grid
from deepreg.registry import REGISTRY
14
15
16
class RandomTransformation3D(tf.keras.layers.Layer):
    """
    An interface for different types of transformation.

    Sub-classes implement ``gen_transform_params`` and ``transform``.
    ``call`` draws one set of parameters for the moving side and one for the
    fixed side, then applies each set to the image and, when labels are
    present, to the label of the same side.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
    ):
        """
        Abstract class for image transformation.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param name: name of layer
        :param trainable: if this layer is trainable
        """
        super().__init__(trainable=trainable, name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size
        # Reference grids are built once here and reused for every call.
        self.moving_grid_ref = get_reference_grid(grid_size=moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=fixed_image_size)

    @abstractmethod
    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates transformation parameters for moving and fixed image.

        :return: two tensors, parameters for the moving image first,
            then parameters for the fixed image
        :raises NotImplementedError: if the sub-class does not override it
        """
        # @abstractmethod is only enforced when the class uses an ABC
        # metaclass, which this class does not declare. Raise explicitly so a
        # missing override fails with a clear error instead of returning None.
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: parameters for transformation
        :return: shape = (batch, dim1, dim2, dim3)
        :raises NotImplementedError: if the sub-class does not override it
        """
        # See gen_transform_params: fail loudly rather than return None.
        raise NotImplementedError

    def call(self, inputs: Dict[str, tf.Tensor], **kwargs) -> Dict[str, tf.Tensor]:
        """
        Creates random params for the input images and their labels,
        and transforms them based on the resampled reference grids.

        :param inputs: a dict having multiple tensors
            if labeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                moving_label, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_label, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
            else, unlabeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
        :param kwargs: other arguments
        :return: dictionary with the same structure as inputs
        """

        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        indices = inputs["indices"]  # passed through untouched

        moving_params, fixed_params = self.gen_transform_params()

        moving_image = self.transform(moving_image, self.moving_grid_ref, moving_params)
        fixed_image = self.transform(fixed_image, self.fixed_grid_ref, fixed_params)

        if "moving_label" not in inputs:  # unlabeled
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )
        moving_label = inputs["moving_label"]
        fixed_label = inputs["fixed_label"]

        # Labels reuse the parameters of their image so that image and label
        # stay aligned after augmentation.
        moving_label = self.transform(moving_label, self.moving_grid_ref, moving_params)
        fixed_label = self.transform(fixed_label, self.fixed_grid_ref, fixed_params)

        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["moving_image_size"] = self.moving_image_size
        config["fixed_image_size"] = self.fixed_image_size
        config["batch_size"] = self.batch_size
        return config
120
121
122
@REGISTRY.register_data_augmentation(name="affine")
class RandomAffineTransform3D(RandomTransformation3D):
    """Augment moving and fixed images with independent random affine maps."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        scale: float = 0.1,
        name: str = "RandomAffineTransform3D",
        **kwargs,
    ):
        """
        Build the layer.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param scale: a positive float controlling the scale of transformation
        :param name: name of the layer
        :param kwargs: extra arguments forwarded to the parent class
        """
        parent_kwargs = dict(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
        )
        super().__init__(**parent_kwargs, **kwargs)
        self.scale = scale

    def get_config(self) -> dict:
        """Return the parent config extended with ``scale``."""
        config = super().get_config()
        config.update(scale=self.scale)
        return config

    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Sample random affine parameters for the moving and the fixed image.

        A single batch of twice the size is generated and cut in half, the
        first half being used for the moving image.

        :return: a tuple of tensors, each has shape = (batch, 4, 3)
        """
        half = self.batch_size
        both = gen_rand_affine_transform(batch_size=half * 2, scale=self.scale)
        moving_params = both[:half]
        fixed_params = both[half:]
        return moving_params, fixed_params

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with the affine parameters, then resample.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        warped_grid = warp_grid(grid_ref, params)
        return resample(vol=image, loc=warped_grid)
185
186
187
@REGISTRY.register_data_augmentation(name="ddf")
class RandomDDFTransform3D(RandomTransformation3D):
    """Apply random DDF transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        field_strength: int = 1,
        low_res_size: tuple = (1, 1, 1),
        name: str = "RandomDDFTransform3D",
        **kwargs,
    ):
        """
        Creates a DDF transformation for data augmentation.

        To simulate smooth deformation fields, a low resolution random field
        of size low_res_size is generated and then resized to the image size.
        The low resolution field is a standard normal sample multiplied by a
        strength drawn uniformly from [0, field_strength].

        :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param batch_size: int, size of mini-batch
        :param field_strength: int = 1. Upper bound of the uniform
            distribution from which the multiplier of the random field
            is drawn.
        :param low_res_size: tuple = (1, 1, 1). Size of the low resolution
            field, must not exceed either image size along any axis.
        :param name: name of layer
        :param kwargs: extra arguments
        :raises AssertionError: if low_res_size exceeds an image size
            along any axis
        """

        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )

        # Compare axis by axis. A plain tuple comparison is lexicographic and
        # would accept e.g. low_res_size=(1, 100, 100) for an image (2, 3, 3).
        # AssertionError is kept so callers relying on it are not broken.
        for image_size in (moving_image_size, fixed_image_size):
            assert all(
                low <= full for low, full in zip(low_res_size, image_size)
            ), (
                f"low_res_size {tuple(low_res_size)} must not exceed "
                f"image size {tuple(image_size)} along any axis"
            )

        self.field_strength = field_strength
        self.low_res_size = low_res_size

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["field_strength"] = self.field_strength
        config["low_res_size"] = self.low_res_size
        return config

    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates two random ddf fields for moving and fixed images.

        :return: tuple, one has shape = (batch, m_dim1, m_dim2, m_dim3, 3)
            another one has shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        kwargs = dict(
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        # Two separate draws: the moving and fixed fields are independent.
        moving = gen_rand_ddf(image_size=self.moving_image_size, **kwargs)
        fixed = gen_rand_ddf(image_size=self.fixed_image_size, **kwargs)
        return moving, fixed

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # The DDF is a displacement, so it is added to the grid after a
        # leading batch axis is inserted for broadcasting.
        return resample(vol=image, loc=grid_ref[None, ...] + params)
269
270
271
def resize_inputs(
    inputs: Dict[str, tf.Tensor], moving_image_size: tuple, fixed_image_size: tuple
) -> Dict[str, tf.Tensor]:
    """
    Resize the images, and the labels when present, to the requested sizes.

    ``indices`` is returned unchanged.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    raw_moving = inputs["moving_image"]
    raw_fixed = inputs["fixed_image"]
    indices = inputs["indices"]

    resized = {
        "moving_image": resize3d(image=raw_moving, size=moving_image_size),
        "fixed_image": resize3d(image=raw_fixed, size=fixed_image_size),
    }

    # The presence of a moving label is what marks the sample as labeled.
    if "moving_label" in inputs:
        raw_moving_label = inputs["moving_label"]
        raw_fixed_label = inputs["fixed_label"]
        resized["moving_label"] = resize3d(
            image=raw_moving_label, size=moving_image_size
        )
        resized["fixed_label"] = resize3d(image=raw_fixed_label, size=fixed_image_size)

    resized["indices"] = indices
    return resized
323
324
325
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: Optional[int] = None
) -> tf.Tensor:
    """
    Function that generates a random 3D transformation parameters for a batch of data.

    for 3D coordinates, affine transformation is

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    where each * represents a degree of freedom,
    so there are in total 12 degrees of freedom
    the equation can be denoted as

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 4)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 4)

    the equation can be simplified to

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    so that

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 3)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 3)

    Given original and transformed coordinates,
    we can calculate the transformation matrix using

        x = np.linalg.lstsq(a, b)

    such that

        a x = b

    In our case,

    - a = old
    - b = new
    - x = T

    To generate random transformation,
    we choose to add random perturbation to corner coordinates as follows:
    for corner of coordinates (x, y, z), the noise is

        -(x, y, z) .* (r1, r2, r3)

    where ri is a random number between (0, scale).
    So

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    Thus, we can directly sample between 1-scale and 1 instead

    We choose to calculate the transformation based on
    four corners in a cube centered at (0, 0, 0).
    A cube is shown as below, where

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    The random numbers come from a generator local to this call, so the
    process-wide NumPy random state is left untouched.

    :param batch_size: int
    :param scale: a float number between 0 and 1
    :param seed: control the randomness, None for a non-deterministic draw
    :return: shape = (batch, 4, 3)
    :raises AssertionError: if scale is outside [0, 1]
    """

    assert 0 <= scale <= 1
    # A local RandomState yields the same stream as np.random.seed(seed)
    # followed by np.random.uniform, without reseeding the global generator.
    rng = np.random.RandomState(seed)
    noise = rng.uniform(1 - scale, 1, [batch_size, 4, 3])  # shape = (batch, 4, 3)

    # old represents four corners of a cube
    # corresponding to the corner C G D A as shown above,
    # in homogeneous coordinates
    old = np.tile(
        [[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, -1, 1], [1, -1, -1, 1]]],
        [batch_size, 1, 1],
    )  # shape = (batch, 4, 4)
    new = old[:, :, :3] * noise  # shape = (batch, 4, 3)

    # Solve old @ T = new for every sample of the batch.
    theta = np.array(
        [
            np.linalg.lstsq(corners, moved, rcond=-1)[0]
            for corners, moved in zip(old, new)
        ]
    )
    # No-op for a non-empty batch; keeps the shape (0, 4, 3) for an empty one.
    theta = theta.reshape((batch_size, 4, 3))  # shape = (batch, 4, 3)

    return tf.cast(theta, dtype=tf.float32)
450
451
452
def gen_rand_ddf(
    batch_size: int,
    image_size: tuple,
    field_strength: Union[float, tuple, list],
    low_res_size: Union[tuple, list],
    seed: Optional[int] = None,
) -> tf.Tensor:
    """
    Function that generates a random 3D DDF for a batch of data.

    A strength is drawn per sample and per displacement component from
    U[0, field_strength]. It multiplies a standard normal field of size
    low_res_size, which is then resized to image_size.

    The random numbers come from a generator local to this call, so the
    process-wide NumPy random state is left untouched.

    :param batch_size: number of fields to generate
    :param image_size: target size passed to resize3d
    :param field_strength: maximum field strength, computed as a U[0,field_strength]
    :param low_res_size: low_resolution deformation field that will be upsampled to
        the original size in order to get smooth and more realistic fields.
        Only its first three entries are used.
    :param seed: control the randomness, None for a non-deterministic draw
    :return: the resized field as returned by resize3d,
        presumably of shape (batch, *image_size, 3)
    """

    # A local RandomState yields the same stream as np.random.seed(seed)
    # followed by the module-level draws, without reseeding the global generator.
    rng = np.random.RandomState(seed)
    # shape = (batch, 1, 1, 1, 3), broadcast over the three spatial axes
    low_res_strength = rng.uniform(0, field_strength, (batch_size, 1, 1, 1, 3))
    low_res_field = low_res_strength * rng.randn(
        batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3
    )
    high_res_field = resize3d(low_res_field, image_size)
    return high_res_field
478