RandomDDFTransform3D.__init__()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 41
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 41
rs 9.45
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more or different data.

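There are several approaches to avoid long parameter lists:

- Replace a parameter with a method call (let the callee obtain the value itself).
- Introduce a parameter object that groups parameters which always travel together.
- Preserve the whole object: pass the object the values come from instead of its individual fields.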

1
"""
2
Module containing data augmentation techniques.
3
  - 3D Affine/DDF Transforms for moving and fixed images.
4
"""
5
6
from abc import abstractmethod
7
from typing import Dict, List, Optional, Tuple, Union
8
9
import numpy as np
10
import tensorflow as tf
11
12
from deepreg.model.layer import Resize3d
13
from deepreg.model.layer_util import get_reference_grid, resample, warp_grid
14
from deepreg.registry import REGISTRY
15
16
17
class RandomTransformation3D(tf.keras.layers.Layer):
    """
    An interface for different types of transformation.

    Subclasses implement :meth:`gen_transform_params` and :meth:`transform`.
    :meth:`call` then warps the moving and fixed images (and their labels,
    when present) with independently generated parameters.
    """

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
        **kwargs,
    ):
        """
        Abstract class for image transformation.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step, over all devices.
        :param name: name of layer
        :param trainable: if this layer is trainable
        :param kwargs: additional arguments forwarded to ``tf.keras.layers.Layer``.
            This lets extra keys produced by ``Layer.get_config`` (e.g. ``dtype``)
            be passed back in, so ``cls.from_config(layer.get_config())`` works.
        """
        super().__init__(trainable=trainable, name=name, **kwargs)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size
        # Reference grids are fixed by the image sizes, so build them once here
        # rather than on every call.
        self.moving_grid_ref = get_reference_grid(grid_size=moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=fixed_image_size)

    @abstractmethod
    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates transformation parameters for moving and fixed image.

        :return: two tensors, (moving_params, fixed_params)
        """
        # ``abstractmethod`` is only enforced when the metaclass derives from
        # ``ABCMeta``, which may not hold for Keras layers. Raising here makes a
        # missing override fail loudly instead of silently returning ``None``.
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: parameters for transformation
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # See gen_transform_params for why this raises explicitly.
        raise NotImplementedError

    def call(self, inputs: Dict[str, tf.Tensor], **kwargs) -> Dict[str, tf.Tensor]:
        """
        Creates random transformation parameters for the input images and
        their labels, and transforms them by resampling on the transformed
        reference grids.

        The moving pair (image and label) shares one set of parameters, and the
        fixed pair shares another.

        :param inputs: a dict having multiple tensors
            if labeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                moving_label, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_label, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
            else, unlabeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
        :param kwargs: other arguments (unused)
        :return: dictionary with the same structure as inputs
        """

        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        indices = inputs["indices"]

        moving_params, fixed_params = self.gen_transform_params()

        moving_image = self.transform(moving_image, self.moving_grid_ref, moving_params)
        fixed_image = self.transform(fixed_image, self.fixed_grid_ref, fixed_params)

        if "moving_label" not in inputs:  # unlabeled
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )
        moving_label = inputs["moving_label"]
        fixed_label = inputs["fixed_label"]

        # Labels reuse their image's parameters so they stay aligned with it.
        moving_label = self.transform(moving_label, self.moving_grid_ref, moving_params)
        fixed_label = self.transform(fixed_label, self.fixed_grid_ref, fixed_params)

        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["moving_image_size"] = self.moving_image_size
        config["fixed_image_size"] = self.fixed_image_size
        config["batch_size"] = self.batch_size
        return config
121
122
123
@REGISTRY.register_data_augmentation(name="affine")
class RandomAffineTransform3D(RandomTransformation3D):
    """Apply random affine transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        scale: float = 0.1,
        name: str = "RandomAffineTransform3D",
        **kwargs,
    ):
        """
        Build an affine augmentation layer.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step, over all devices.
        :param scale: amount of random corner perturbation; gen_rand_affine_transform
            asserts it lies in [0, 1]
        :param name: name of the layer
        :param kwargs: additional arguments passed on to the parent layer
        """
        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )
        self.scale = scale

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {**super().get_config(), "scale": self.scale}

    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Draw random affine parameters for one moving and one fixed batch.

        A single call produces ``2 * batch_size`` matrices; the first half is
        used for the moving images and the second half for the fixed images.

        :return: (moving_theta, fixed_theta), each of shape (batch, 4, 3)
        """
        half = self.batch_size
        thetas = gen_rand_affine_transform(batch_size=half * 2, scale=self.scale)
        moving_theta = thetas[:half]
        fixed_theta = thetas[half:]
        return moving_theta, fixed_theta

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with the affine matrices, then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: affine matrices, shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        warped_grid = warp_grid(grid_ref, params)
        return resample(vol=image, loc=warped_grid)
186
187
188
@REGISTRY.register_data_augmentation(name="ddf")
class RandomDDFTransform3D(RandomTransformation3D):
    """Apply random DDF transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        field_strength: Union[int, float] = 1,
        low_res_size: tuple = (1, 1, 1),
        name: str = "RandomDDFTransform3D",
        **kwargs,
    ):
        """
        Creates a DDF transformation for data augmentation.

        To simulate smooth deformation fields, a random field of size
        low_res_size is generated and then upsampled to the image size.
        For each sample and each displacement component, the standard
        deviation of the low-resolution field is drawn uniformly from
        [0, field_strength].

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step, over all devices.
        :param field_strength: upper bound of the per-sample, per-axis standard
            deviation of the low-resolution deformation field, default 1.
        :param low_res_size: size of the low-resolution field, default (1, 1, 1).
            It must have as many entries as each image size, and no entry may
            exceed the corresponding image dimension.
        :param name: name of layer
        :param kwargs: additional arguments passed on to the parent layer
        :raises AssertionError: if low_res_size is incompatible with either image size
        """
        # Compare element-wise: tuple "<=" is lexicographic, so e.g.
        # (2, 100, 100) <= (3, 8, 8) would wrongly be accepted.
        for image_size in (moving_image_size, fixed_image_size):
            assert len(low_res_size) == len(image_size) and all(
                low <= full for low, full in zip(low_res_size, image_size)
            ), (
                f"low_res_size {tuple(low_res_size)} must not exceed "
                f"image size {tuple(image_size)} in any dimension"
            )

        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )

        self.field_strength = field_strength
        self.low_res_size = low_res_size

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["field_strength"] = self.field_strength
        config["low_res_size"] = self.low_res_size
        return config

    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates two random ddf fields for moving and fixed images.

        :return: tuple, one has shape = (batch, m_dim1, m_dim2, m_dim3, 3)
            another one has shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        moving = gen_rand_ddf(
            image_size=self.moving_image_size,
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        fixed = gen_rand_ddf(
            image_size=self.fixed_image_size,
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        return moving, fixed

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # Add a batch axis to the grid so it broadcasts against the batched DDF.
        return resample(vol=image, loc=grid_ref[None, ...] + params)
275
276
277
def resize_inputs(
    inputs: Dict[str, tf.Tensor],
    moving_image_size: Tuple[int, ...],
    fixed_image_size: tuple,
) -> Dict[str, tf.Tensor]:
    """
    Resize the images (and labels, if present) of one sample to fixed sizes.

    Moving tensors are resized to moving_image_size and fixed tensors to
    fixed_image_size. Indices are passed through unchanged.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: Tuple[int, ...], (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: Tuple[int, ...], (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    moving_image, fixed_image, indices = (
        inputs[key] for key in ("moving_image", "fixed_image", "indices")
    )

    resize_moving = Resize3d(shape=moving_image_size)
    resize_fixed = Resize3d(shape=fixed_image_size)

    outputs = {
        "moving_image": resize_moving(moving_image),
        "fixed_image": resize_fixed(fixed_image),
    }
    if "moving_label" in inputs:  # labeled
        moving_label, fixed_label = inputs["moving_label"], inputs["fixed_label"]
        outputs["moving_label"] = resize_moving(moving_label)
        outputs["fixed_label"] = resize_fixed(fixed_label)
    outputs["indices"] = indices
    return outputs
334
335
336
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: Optional[int] = None
) -> tf.Tensor:
    """
    Function that generates a random 3D transformation parameters for a batch of data.

    for 3D coordinates, affine transformation is

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    where each * represents a degree of freedom,
    so there are in total 12 degrees of freedom
    the equation can be denoted as

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 4)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 4)

    the equation can be simplified to

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    so that

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 3)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 3)

    Given original and transformed coordinates,
    we can calculate the transformation matrix using

        x = np.linalg.lstsq(a, b)

    such that

        a x = b

    In our case,

    - a = old
    - b = new
    - x = T

    To generate random transformation,
    we choose to add random perturbation to corner coordinates as follows:
    for corner of coordinates (x, y, z), the noise is

        -(x, y, z) .* (r1, r2, r3)

    where ri is a random number between (0, scale).
    So

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    Thus, we can directly sample between 1-scale and 1 instead

    We choose to calculate the transformation based on
    four corners in a cube centered at (0, 0, 0).
    A cube is shown as below, where

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    Randomness comes from a private ``np.random.RandomState(seed)``. A given
    ``seed`` therefore always gives the same result, and the global
    ``np.random`` state is neither read nor modified.

    :param batch_size: total number of samples consumed per step, over all devices.
    :param scale: a float number between 0 and 1
    :param seed: control the randomness; None draws fresh entropy
    :return: shape = (batch, 4, 3)
    """

    assert 0 <= scale <= 1, f"scale must be in [0, 1], got {scale}"
    # A private generator avoids reseeding the process-wide NumPy RNG, which
    # would override any seed set elsewhere. A legacy RandomState seeded with
    # `seed` yields exactly the draws that np.random.seed(seed) would.
    rng = np.random.RandomState(seed)
    noise = rng.uniform(1 - scale, 1, [batch_size, 4, 3])  # shape = (batch, 4, 3)

    # old represents four corners of a cube
    # corresponding to the corner C G D A as shown above
    old = np.tile(
        [[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, -1, 1], [1, -1, -1, 1]]],
        [batch_size, 1, 1],
    )  # shape = (batch, 4, 4)
    new = old[:, :, :3] * noise  # shape = (batch, 4, 3)

    # Solve old[k] @ theta[k] = new[k] for each sample; the 4x4 corner matrix is
    # invertible, so the least-squares solution is exact.
    theta = np.array(
        [np.linalg.lstsq(a, b, rcond=-1)[0] for a, b in zip(old, new)]
    )  # shape = (batch, 4, 3)

    return tf.cast(theta, dtype=tf.float32)
461
462
463
def gen_rand_ddf(
    batch_size: int,
    image_size: Tuple[int, ...],
    field_strength: Union[Tuple, List, int, float],
    low_res_size: Union[Tuple, List],
    seed: Optional[int] = None,
) -> tf.Tensor:
    """
    Function that generates a random 3D DDF for a batch of data.

    A Gaussian field of size low_res_size is drawn with 3 displacement
    components per voxel. For each sample and each component it is scaled by
    a factor drawn from U[0, field_strength]. The field is then resized to
    image_size with Resize3d, giving a smooth deformation.

    Randomness comes from a private ``np.random.RandomState(seed)``, so the
    global ``np.random`` state is neither read nor modified.

    :param batch_size: total number of samples consumed per step, over all devices.
    :param image_size: (dim1, dim2, dim3), spatial size of the output field
    :param field_strength: upper bound of the uniform distribution for the
        per-sample, per-component scale (i.e. standard deviation) of the field.
        A length-3 sequence gives one bound per displacement component.
    :param low_res_size: low_resolution deformation field that will be upsampled to
        the original size in order to get smooth and more realistic fields.
        Only its first three entries are used.
    :param seed: control the randomness; None draws fresh entropy
    :return: DDF, expected shape = (batch, dim1, dim2, dim3, 3)
    """

    # Private generator: avoids reseeding the process-wide NumPy RNG. For a
    # given seed, the draws match what np.random.seed(seed) would produce.
    rng = np.random.RandomState(seed)
    # One scale per sample and per displacement component, broadcast over space.
    low_res_strength = rng.uniform(0, field_strength, (batch_size, 1, 1, 1, 3))
    low_res_field = low_res_strength * rng.randn(
        batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3
    )
    high_res_field = Resize3d(shape=image_size)(low_res_field)
    return high_res_field
489