gen_rand_affine_transform()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 125
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 125
rs 9.8
c 0
b 0
f 0
cc 1
nop 3

How to fix: Long Method

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent part of the body into a new, well-named method.

1
"""
2
Module containing data augmentation techniques.
3
  - 3D Affine/DDF Transforms for moving and fixed images.
4
"""
5
6
from abc import abstractmethod
7
from typing import Dict, List, Optional, Tuple, Union
8
9
import numpy as np
10
import tensorflow as tf
11
12
from deepreg.model.layer import Resize3d
13
from deepreg.model.layer_util import get_reference_grid, resample, warp_grid
14
from deepreg.registry import REGISTRY
15
16
17
class RandomTransformation3D(tf.keras.layers.Layer):
    """
    An interface for different types of transformation.

    Subclasses provide :meth:`gen_transform_params` (one parameter tensor for the
    moving side and one for the fixed side) and :meth:`transform` (apply those
    parameters to an image). :meth:`call` then warps moving/fixed images, and
    their labels if present, with the corresponding parameters.
    """

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
    ):
        """
        Abstract class for image transformation.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step, over all devices.
        :param name: name of layer
        :param trainable: if this layer is trainable
        """
        super().__init__(trainable=trainable, name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size
        # Reference grids are computed once here and reused by every call.
        self.moving_grid_ref = get_reference_grid(grid_size=moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=fixed_image_size)

    @abstractmethod
    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates transformation parameters for moving and fixed image.

        :return: two tensors, (moving_params, fixed_params)
        :raises NotImplementedError: if a subclass does not override this method
        """
        # Raise explicitly: @abstractmethod is only enforced under an ABC metaclass,
        # and without this body an incomplete subclass would return None and fail
        # later with an opaque "cannot unpack" error inside `call`.
        raise NotImplementedError(
            f"{type(self).__name__} must implement gen_transform_params"
        )

    @staticmethod
    @abstractmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: parameters for transformation
        :return: shape = (batch, dim1, dim2, dim3)
        :raises NotImplementedError: if a subclass does not override this method
        """
        raise NotImplementedError("subclasses must implement transform")

    def call(self, inputs: Dict[str, tf.Tensor], **kwargs) -> Dict[str, tf.Tensor]:
        """
        Creates random params for the input images and their labels,
        and transforms them by resampling on the warped reference grids.

        Labels reuse the params of their image so image and label stay aligned.

        :param inputs: a dict having multiple tensors
            if labeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                moving_label, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_label, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
            else, unlabeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
        :param kwargs: other arguments
        :return: dictionary with the same structure as inputs
        """

        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        indices = inputs["indices"]

        moving_params, fixed_params = self.gen_transform_params()

        moving_image = self.transform(moving_image, self.moving_grid_ref, moving_params)
        fixed_image = self.transform(fixed_image, self.fixed_grid_ref, fixed_params)

        if "moving_label" not in inputs:  # unlabeled
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )
        moving_label = inputs["moving_label"]
        fixed_label = inputs["fixed_label"]

        # Same params as the corresponding image, so labels stay registered to it.
        moving_label = self.transform(moving_label, self.moving_grid_ref, moving_params)
        fixed_label = self.transform(fixed_label, self.fixed_grid_ref, fixed_params)

        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["moving_image_size"] = self.moving_image_size
        config["fixed_image_size"] = self.fixed_image_size
        config["batch_size"] = self.batch_size
        return config
121
122
123
@REGISTRY.register_data_augmentation(name="affine")
class RandomAffineTransform3D(RandomTransformation3D):
    """Warp moving and fixed images with independently drawn random affine maps."""

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        scale: float = 0.1,
        name: str = "RandomAffineTransform3D",
        **kwargs,
    ):
        """
        Build the affine augmentation layer.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step, over all devices.
        :param scale: float in [0, 1]; larger values allow stronger deformations
            (validated by gen_rand_affine_transform)
        :param name: name of the layer
        :param kwargs: forwarded to RandomTransformation3D
        """
        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )
        self.scale = scale

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        layer_config = super().get_config()
        layer_config.update(scale=self.scale)
        return layer_config

    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Draw one batch of affine matrices for the moving images and one for the fixed.

        Both halves come from a single draw of size 2 * batch_size.

        :return: (moving_params, fixed_params), each of shape = (batch, 4, 3)
        """
        n = self.batch_size
        all_params = gen_rand_affine_transform(batch_size=2 * n, scale=self.scale)
        moving_params, fixed_params = all_params[:n], all_params[n:]
        return moving_params, fixed_params

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with the affine params and resample the image on it.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        warped_grid = warp_grid(grid_ref, params)
        return resample(vol=image, loc=warped_grid)
186
187
188
@REGISTRY.register_data_augmentation(name="ddf")
class RandomDDFTransform3D(RandomTransformation3D):
    """Apply random DDF transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        field_strength: int = 1,
        low_res_size: tuple = (1, 1, 1),
        name: str = "RandomDDFTransform3D",
        **kwargs,
    ):
        """
        Creates a DDF transformation for data augmentation.

        To simulate smooth deformation fields, a low resolution field of size
        low_res_size is drawn and then resized (via Resize3d) to the image size.
        The standard deviation of the low resolution field is drawn from a
        uniform variable between [0, field_strength].

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step, over all devices.
        :param field_strength: upper bound for the standard deviation of the
            deformation field, default 1.
        :param low_res_size: size of the low resolution field, default (1, 1, 1);
            every dimension must not exceed the matching image dimension.
        :param name: name of layer
        :param kwargs: additional arguments
        """

        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )

        # Compare dimension by dimension: tuple `<=` is lexicographic, so e.g.
        # (1, 99, 99) <= (2, 2, 2) would wrongly pass.
        assert self._fits_within(low_res_size, moving_image_size)
        assert self._fits_within(low_res_size, fixed_image_size)

        self.field_strength = field_strength
        self.low_res_size = low_res_size

    @staticmethod
    def _fits_within(small_size, large_size) -> bool:
        """Return True if both sizes have equal rank and no dim of small exceeds large."""
        small_size, large_size = tuple(small_size), tuple(large_size)
        return len(small_size) == len(large_size) and all(
            lo <= hi for lo, hi in zip(small_size, large_size)
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["field_strength"] = self.field_strength
        config["low_res_size"] = self.low_res_size
        return config

    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates two random ddf fields for moving and fixed images.

        :return: tuple, one has shape = (batch, m_dim1, m_dim2, m_dim3, 3)
            another one has shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        moving = gen_rand_ddf(
            image_size=self.moving_image_size,
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        fixed = gen_rand_ddf(
            image_size=self.fixed_image_size,
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        return moving, fixed

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # Broadcast the reference grid over the batch and displace it by the DDF.
        return resample(vol=image, loc=grid_ref[None, ...] + params)
275
276
277
def resize_inputs(
    inputs: Dict[str, tf.Tensor],
    moving_image_size: Tuple[int, ...],
    fixed_image_size: tuple,
) -> Dict[str, tf.Tensor]:
    """
    Resize the moving/fixed images (and labels, when present) to target sizes.

    One Resize3d layer is built per side and shared between that side's image
    and label. ``indices`` is passed through untouched.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: Tuple[int, ...], (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: Tuple[int, ...], (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    raw_moving, raw_fixed, passthrough_indices = (
        inputs["moving_image"],
        inputs["fixed_image"],
        inputs["indices"],
    )

    resize_moving = Resize3d(shape=moving_image_size)
    resize_fixed = Resize3d(shape=fixed_image_size)

    resized = {
        "moving_image": resize_moving(raw_moving),
        "fixed_image": resize_fixed(raw_fixed),
    }

    if "moving_label" in inputs:  # labeled
        raw_moving_label, raw_fixed_label = inputs["moving_label"], inputs["fixed_label"]
        resized["moving_label"] = resize_moving(raw_moving_label)
        resized["fixed_label"] = resize_fixed(raw_fixed_label)

    resized["indices"] = passthrough_indices
    return resized
334
335
336
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: Optional[int] = None
) -> tf.Tensor:
    """
    Function that generates a random 3D transformation parameters for a batch of data.

    for 3D coordinates, affine transformation is

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    where each * represents a degree of freedom,
    so there are in total 12 degrees of freedom.
    Dropping the constant last column, the equation can be simplified to

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    i.e. new = old * T where

    - new is the transformed coordinates, of shape (1, 3)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 3)

    Given original and transformed coordinates, T is the solution x of
    ``a x = b`` computed by ``np.linalg.lstsq(a, b)`` with a = old, b = new.

    To generate a random transformation, we perturb corner coordinates:
    for a corner (x, y, z), the noise is

        -(x, y, z) .* (r1, r2, r3)

    where ri is a random number between (0, scale). So

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    Thus, we can directly sample between 1-scale and 1 instead.

    We choose to calculate the transformation based on
    four corners in a cube centered at (0, 0, 0).
    A cube is shown as below, where

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    :param batch_size: total number of samples consumed per step, over all devices.
    :param scale: a float number between 0 and 1
    :param seed: control the randomness; does not touch NumPy's global RNG state
    :return: shape = (batch, 4, 3)
    """

    assert 0 <= scale <= 1
    # A private generator avoids reseeding (and clobbering) NumPy's global RNG on
    # every call; RandomState(seed) yields the same draws np.random.seed(seed) did.
    # NOTE(review): NumPy draws are not graph ops — if this runs while a
    # tf.function / tf.data map is being traced, the values are frozen at trace
    # time. Confirm callers invoke it eagerly.
    rng = np.random.RandomState(seed)
    noise = rng.uniform(1 - scale, 1, [batch_size, 4, 3])  # shape = (batch, 4, 3)

    # corners C G D A of the cube above, in homogeneous coordinates
    old = np.array(
        [[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, -1, 1], [1, -1, -1, 1]]
    )  # shape = (4, 4), identical for every sample
    new = old[None, :, :3] * noise  # shape = (batch, 4, 3)

    # reshape keeps the documented (batch, 4, 3) shape even when batch_size == 0
    theta = np.array(
        [np.linalg.lstsq(old, new_k, rcond=-1)[0] for new_k in new],
        dtype=np.float64,
    ).reshape(batch_size, 4, 3)

    return tf.cast(theta, dtype=tf.float32)
461
462
463
def gen_rand_ddf(
    batch_size: int,
    image_size: Tuple[int, ...],
    field_strength: Union[Tuple, List, int, float],
    low_res_size: Union[Tuple, List],
    seed: Optional[int] = None,
) -> tf.Tensor:
    """
    Function that generates a random 3D DDF for a batch of data.

    A low resolution field of shape (batch, *low_res_size[:3], 3) is sampled from
    a zero-mean normal distribution. Its standard deviation is drawn per sample
    and per displacement axis from U[0, field_strength). The field is then
    resized to ``image_size`` so the result is smooth.

    :param batch_size: total number of samples consumed per step, over all devices.
    :param image_size: (dim1, dim2, dim3), spatial size of the returned field
    :param field_strength: maximum field strength, computed as a U[0,field_strength];
        a length-3 sequence gives one bound per displacement axis.
    :param low_res_size: low_resolution deformation field that will be upsampled to
        the original size in order to get smooth and more realistic fields.
        Only its first three entries are used.
    :param seed: control the randomness; does not touch NumPy's global RNG state
    :return: shape = (batch, dim1, dim2, dim3, 3)
    """

    # A private generator avoids reseeding (and clobbering) NumPy's global RNG;
    # RandomState(seed) produces the same draws np.random.seed(seed) did.
    rng = np.random.RandomState(seed)
    low_res_strength = rng.uniform(0, field_strength, (batch_size, 1, 1, 1, 3))
    low_res_field = low_res_strength * rng.randn(
        batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3
    )
    high_res_field = Resize3d(shape=image_size)(low_res_field)
    return high_res_field
489