Passed
Pull Request — main (#656)
by Yunguan
02:51
created

RandomDDFTransform3D.__init__()   A

Complexity

Conditions 1

Size

Total Lines 41
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 41
rs 9.45
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more or different data.

There are several approaches to avoid long parameter lists:

1
"""
2
Module containing data augmentation techniques.
3
  - 3D Affine/DDF Transforms for moving and fixed images.
4
"""
5
6
from abc import abstractmethod
7
from typing import Dict
8
9
import numpy as np
10
import tensorflow as tf
11
12
from deepreg.model.layer import Resize3d
13
from deepreg.model.layer_util import get_reference_grid, resample, warp_grid
14
from deepreg.registry import REGISTRY
15
16
17
class RandomTransformation3D(tf.keras.layers.Layer):
    """
    Base layer for random spatial transformations of moving/fixed image pairs.

    Subclasses decide how random parameters are sampled
    (``gen_transform_params``) and how one parameter set is applied to an
    image (``transform``); this base class only routes the input dictionary.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
    ):
        """
        Store the image sizes and build one reference grid per image.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param name: name of layer
        :param trainable: if this layer is trainable
        """
        super().__init__(trainable=trainable, name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size
        # The reference grids depend only on the sizes, so compute them once.
        self.moving_grid_ref = get_reference_grid(grid_size=moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=fixed_image_size)

    @abstractmethod
    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Sample one set of transformation parameters per image.

        :return: (moving_params, fixed_params)
        """

    @staticmethod
    @abstractmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with ``params`` and resample ``image`` on it.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: parameters for transformation
        :return: shape = (batch, dim1, dim2, dim3)
        """

    def call(self, inputs: Dict[str, tf.Tensor], **kwargs) -> Dict[str, tf.Tensor]:
        """
        Randomly transform the images (and labels, if present) of a batch.

        One parameter set is drawn for the moving side and one for the fixed
        side; each label is warped with the same parameters as its image.

        :param inputs: a dict having multiple tensors
            if labeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                moving_label, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_label, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
            else, unlabeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
        :param kwargs: other arguments
        :return: dictionary with the same structure as inputs
        """
        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        indices = inputs["indices"]

        moving_params, fixed_params = self.gen_transform_params()

        outputs = dict(
            moving_image=self.transform(
                moving_image, self.moving_grid_ref, moving_params
            ),
            fixed_image=self.transform(fixed_image, self.fixed_grid_ref, fixed_params),
        )

        if "moving_label" in inputs:  # labeled
            moving_label = inputs["moving_label"]
            fixed_label = inputs["fixed_label"]
            outputs["moving_label"] = self.transform(
                moving_label, self.moving_grid_ref, moving_params
            )
            outputs["fixed_label"] = self.transform(
                fixed_label, self.fixed_grid_ref, fixed_params
            )

        # indices pass through untouched and stay the last key, as before.
        outputs["indices"] = indices
        return outputs

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            batch_size=self.batch_size,
        )
        return config
121
122
123
@REGISTRY.register_data_augmentation(name="affine")
class RandomAffineTransform3D(RandomTransformation3D):
    """Apply random affine transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        scale: float = 0.1,
        name: str = "RandomAffineTransform3D",
        **kwargs,
    ):
        """
        Build a random affine augmentation layer.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param scale: a positive float controlling the scale of transformation,
            forwarded to ``gen_rand_affine_transform``
        :param name: name of the layer
        :param kwargs: extra arguments forwarded to the base class
        """
        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )
        self.scale = scale

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(scale=self.scale)
        return config

    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Sample affine parameters for the moving and the fixed batch.

        :return: a tuple of tensors, each has shape = (batch, 4, 3)
        """
        n = self.batch_size
        # Draw both halves at once, then split: first n -> moving, rest -> fixed.
        thetas = gen_rand_affine_transform(batch_size=n * 2, scale=self.scale)
        moving_theta, fixed_theta = thetas[:n], thetas[n:]
        return moving_theta, fixed_theta

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with affine ``params`` and resample ``image``.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        warped_grid = warp_grid(grid_ref, params)
        return resample(vol=image, loc=warped_grid)
186
187
188
@REGISTRY.register_data_augmentation(name="ddf")
class RandomDDFTransform3D(RandomTransformation3D):
    """Apply random DDF transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        field_strength: int = 1,
        low_res_size: tuple = (1, 1, 1),
        name: str = "RandomDDFTransform3D",
        **kwargs,
    ):
        """
        Creates a DDF transformation for data augmentation.

        To simulate smooth deformation fields, a random field of size
        low_res_size is generated and then resized to the image size (see
        ``gen_rand_ddf``). For each sample and axis, the scale (standard
        deviation) of the low-resolution field is drawn uniformly from
        [0, field_strength].

        :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param batch_size: int, size of mini-batch
        :param field_strength: int = 1. Upper bound of the uniformly sampled
            scale of the low-resolution deformation field.
        :param low_res_size: tuple = (1, 1, 1). Size of the low-resolution
            field; must not exceed either image size in any dimension.
        :param name: name of layer
        :param kwargs: extra arguments forwarded to the base class
        """
        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )

        # Check dimension by dimension. Plain tuple ``<=`` is lexicographic,
        # so e.g. (2, 100, 100) <= (3, 3, 3) would wrongly be accepted.
        for image_size in (moving_image_size, fixed_image_size):
            assert all(
                low <= high for low, high in zip(low_res_size, image_size)
            ), (
                f"low_res_size {tuple(low_res_size)} must not exceed "
                f"image size {tuple(image_size)} in any dimension"
            )

        self.field_strength = field_strength
        self.low_res_size = low_res_size

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["field_strength"] = self.field_strength
        config["low_res_size"] = self.low_res_size
        return config

    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Generates two random ddf fields for moving and fixed images.

        :return: tuple, one has shape = (batch, m_dim1, m_dim2, m_dim3, 3)
            another one has shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        kwargs = dict(
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        moving = gen_rand_ddf(image_size=self.moving_image_size, **kwargs)
        fixed = gen_rand_ddf(image_size=self.fixed_image_size, **kwargs)
        return moving, fixed

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Displace the reference grid by the DDF and resample the image on it.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # grid_ref[None, ...] adds a batch axis so it broadcasts against the DDF.
        return resample(vol=image, loc=grid_ref[None, ...] + params)
270
271
272
def resize_inputs(
    inputs: Dict[str, tf.Tensor], moving_image_size: tuple, fixed_image_size: tuple
) -> Dict[str, tf.Tensor]:
    """
    Resize the images (and labels, if present) to fixed target sizes.

    Moving tensors share one resize layer and fixed tensors share another;
    indices are passed through unchanged.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    # Fetch the required keys first, so a missing key fails before any work.
    moving_image, fixed_image, indices = (
        inputs[key] for key in ("moving_image", "fixed_image", "indices")
    )

    resize_moving = Resize3d(shape=moving_image_size)
    resize_fixed = Resize3d(shape=fixed_image_size)

    outputs = dict(
        moving_image=resize_moving(moving_image),
        fixed_image=resize_fixed(fixed_image),
    )

    if "moving_label" in inputs:  # labeled
        moving_label, fixed_label = inputs["moving_label"], inputs["fixed_label"]
        outputs["moving_label"] = resize_moving(moving_label)
        outputs["fixed_label"] = resize_fixed(fixed_label)

    outputs["indices"] = indices
    return outputs
327
328
329
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: (int, None) = None
) -> tf.Tensor:
    """
    Function that generates a random 3D transformation parameters for a batch of data.

    for 3D coordinates, affine transformation is

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    where each * represents a degree of freedom,
    so there are in total 12 degrees of freedom
    the equation can be denoted as

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 4)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 4)

    the equation can be simplified to

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    so that

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 3)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 3)

    Given original and transformed coordinates,
    we can calculate the transformation matrix using

        x = np.linalg.lstsq(a, b)

    such that

        a x = b

    In our case,

    - a = old
    - b = new
    - x = T

    To generate random transformation,
    we choose to add random perturbation to corner coordinates as follows:
    for corner of coordinates (x, y, z), the noise is

        -(x, y, z) .* (r1, r2, r3)

    where ri is a random number between (0, scale).
    So

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    Thus, we can directly sample between 1-scale and 1 instead

    We choose to calculate the transformation based on
    four corners in a cube centered at (0, 0, 0).
    A cube is shown as below, where

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    :param batch_size: int
    :param scale: a float number between 0 and 1
    :param seed: control the randomness. A private generator is used, so
        numpy's global random state is left untouched; for a given seed the
        sampled values match those of ``np.random.seed(seed)``.
    :return: shape = (batch, 4, 3)
    """

    assert 0 <= scale <= 1
    # Local legacy generator: same stream as np.random.seed(seed) would give,
    # without resetting the global RNG (seed=None would otherwise wipe any
    # seed the user set for reproducibility).
    rng = np.random.RandomState(seed)
    noise = rng.uniform(1 - scale, 1, [batch_size, 4, 3])  # shape = (batch, 4, 3)

    # old represents four corners of a cube
    # corresponding to the corner C G D A as shown above
    old = np.tile(
        [[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, -1, 1], [1, -1, -1, 1]]],
        [batch_size, 1, 1],
    )  # shape = (batch, 4, 4)
    new = old[:, :, :3] * noise  # shape = (batch, 4, 3)

    theta = np.array(
        [np.linalg.lstsq(old[k], new[k], rcond=-1)[0] for k in range(batch_size)]
    )  # shape = (batch, 4, 3)

    return tf.cast(theta, dtype=tf.float32)
454
455
456
def gen_rand_ddf(
    batch_size: int,
    image_size: tuple,
    field_strength: (int, float, tuple, list),
    low_res_size: (tuple, list),
    seed: (int, None) = None,
) -> tf.Tensor:
    """
    Function that generates a random 3D DDF for a batch of data.

    A random field of shape (batch, *low_res_size, 3) is sampled and resized
    to image_size with Resize3d, giving a smoother field than sampling at
    full resolution.

    :param batch_size: number of fields to generate
    :param image_size: (dim1, dim2, dim3), spatial size of the returned field
    :param field_strength: upper bound of the field scale. For each sample and
        axis, a scale s ~ U[0, field_strength] is drawn and multiplied with
        standard-normal noise, so s acts as the standard deviation of the
        low-resolution field. A scalar or a length-3 sequence (one bound per
        displacement axis) both broadcast.
    :param low_res_size: low_resolution deformation field that will be upsampled to
        the original size in order to get smooth and more realistic fields.
        Only the first three entries are used.
    :param seed: control the randomness. A private generator is used, so
        numpy's global random state is left untouched; for a given seed the
        sampled values match those of ``np.random.seed(seed)``.
    :return: shape = (batch_size, dim1, dim2, dim3, 3)
    """
    # Local legacy generator: same stream as np.random.seed(seed) would give,
    # without resetting the global RNG on every augmentation call.
    rng = np.random.RandomState(seed)
    low_res_strength = rng.uniform(0, field_strength, (batch_size, 1, 1, 1, 3))
    low_res_field = low_res_strength * rng.randn(
        batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3
    )
    high_res_field = Resize3d(shape=image_size)(low_res_field)
    return high_res_field
482