Passed
Pull Request — main (#656)
by Yunguan
02:36
created

RandomDDFTransform3D.__init__()   A

Complexity

Conditions 1

Size

Total Lines 41
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 41
rs 9.45
c 0
b 0
f 0
cc 1
nop 8

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
"""
2
Module containing data augmentation techniques.
3
  - 3D Affine/DDF Transforms for moving and fixed images.
4
"""
5
6
from abc import abstractmethod
7
from typing import Dict
8
9
import numpy as np
10
import tensorflow as tf
11
12
from deepreg.model.layer_util import get_reference_grid, resample, resize3d, warp_grid
13
from deepreg.registry import REGISTRY
14
15
16
class RandomTransformation3D(tf.keras.layers.Layer):
    """
    An interface for different types of transformation.

    Subclasses implement ``gen_transform_params`` and ``transform``; ``call``
    applies the generated parameters to the moving/fixed images (and labels,
    when present).
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
        **kwargs,
    ):
        """
        Abstract class for image transformation.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param name: name of layer
        :param trainable: if this layer is trainable
        :param kwargs: additional keyword arguments forwarded to
            ``tf.keras.layers.Layer.__init__`` (e.g. ``dtype``). Forwarding them
            lets ``cls.from_config(layer.get_config())`` round-trip, because the
            base ``Layer.get_config`` includes keys such as ``dtype``.
        """
        super().__init__(trainable=trainable, name=name, **kwargs)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size
        # reference grids are computed once and reused on every call
        self.moving_grid_ref = get_reference_grid(grid_size=moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=fixed_image_size)

    @abstractmethod
    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Generates transformation parameters for moving and fixed image.

        :return: two tensors
        """

    @staticmethod
    @abstractmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: parameters for transformation
        :return: shape = (batch, dim1, dim2, dim3)
        """

    def call(self, inputs: Dict[str, tf.Tensor], **kwargs) -> Dict[str, tf.Tensor]:
        """
        Creates random transformation params and applies them to the inputs.

        One set of params is generated for the moving side and one for the fixed
        side. Each image (and its label, if labeled) is transformed with its
        side's params and reference grid, so an image and its label always
        receive the same transformation.

        :param inputs: a dict having multiple tensors
            if labeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                moving_label, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_label, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
            else, unlabeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
        :param kwargs: other arguments
        :return: dictionary with the same structure as inputs; indices are
            passed through unchanged
        """

        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        indices = inputs["indices"]

        moving_params, fixed_params = self.gen_transform_params()

        moving_image = self.transform(moving_image, self.moving_grid_ref, moving_params)
        fixed_image = self.transform(fixed_image, self.fixed_grid_ref, fixed_params)

        if "moving_label" not in inputs:  # unlabeled
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )
        moving_label = inputs["moving_label"]
        fixed_label = inputs["fixed_label"]

        # labels reuse the same params as their images
        moving_label = self.transform(moving_label, self.moving_grid_ref, moving_params)
        fixed_label = self.transform(fixed_label, self.fixed_grid_ref, fixed_params)

        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["moving_image_size"] = self.moving_image_size
        config["fixed_image_size"] = self.fixed_image_size
        config["batch_size"] = self.batch_size
        return config
120
121
122
@REGISTRY.register_data_augmentation(name="affine")
class RandomAffineTransform3D(RandomTransformation3D):
    """Apply random affine transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        scale: float = 0.1,
        name: str = "RandomAffineTransform3D",
        **kwargs,
    ):
        """
        Build a layer that applies independent random affine transforms.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param scale: a positive float controlling the scale of transformation,
            forwarded to ``gen_rand_affine_transform``
        :param name: name of the layer
        :param kwargs: extra arguments forwarded to the parent class
        """
        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )
        self.scale = scale

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {**super().get_config(), "scale": self.scale}

    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Draw random affine parameters for the moving and fixed batches.

        A single call samples 2 * batch_size matrices; the first half is used for
        the moving images and the second half for the fixed images.

        :return: a tuple of tensors, each has shape = (batch, 4, 3)
        """
        n = self.batch_size
        thetas = gen_rand_affine_transform(batch_size=2 * n, scale=self.scale)
        moving_theta, fixed_theta = thetas[:n], thetas[n:]
        return moving_theta, fixed_theta

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with the affine params, then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        warped_grid = warp_grid(grid_ref, params)
        return resample(vol=image, loc=warped_grid)
185
186
187
@REGISTRY.register_data_augmentation(name="ddf")
class RandomDDFTransform3D(RandomTransformation3D):
    """Apply random DDF transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        field_strength: int = 1,
        low_res_size: tuple = (1, 1, 1),
        name: str = "RandomDDFTransform3D",
        **kwargs,
    ):
        """
        Creates a DDF transformation for data augmentation.

        To simulate smooth deformation fields, we interpolate from a low resolution
        field of size low_res_size using linear interpolation. The scale (standard
        deviation) of the deformation field is drawn from a uniform variable
        between [0, field_strength].

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param field_strength: int = 1. It is used as the upper bound for the
            scale of the deformation field
        :param low_res_size: tuple = (1, 1, 1). Size of the low resolution field;
            must have the same length as, and be no larger in any dimension than,
            both moving_image_size and fixed_image_size
        :param name: name of layer
        :param kwargs: extra arguments
        :raises AssertionError: if low_res_size does not fit within both image
            sizes dimension by dimension
        """

        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )

        # Validate per dimension: plain tuple comparison (``<=``) is
        # lexicographic, so e.g. (1, 100, 1) <= (2, 2, 2) would wrongly pass.
        assert self._fits_within(low_res_size, moving_image_size), (
            f"low_res_size {low_res_size} must fit within "
            f"moving_image_size {moving_image_size} in every dimension"
        )
        assert self._fits_within(low_res_size, fixed_image_size), (
            f"low_res_size {low_res_size} must fit within "
            f"fixed_image_size {fixed_image_size} in every dimension"
        )

        self.field_strength = field_strength
        self.low_res_size = low_res_size

    @staticmethod
    def _fits_within(low_res_size, image_size) -> bool:
        """Return True if both sizes have equal length and low <= image per dim."""
        low, high = tuple(low_res_size), tuple(image_size)
        return len(low) == len(high) and all(a <= b for a, b in zip(low, high))

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["field_strength"] = self.field_strength
        config["low_res_size"] = self.low_res_size
        return config

    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Generates two random ddf fields for moving and fixed images.

        :return: tuple, one has shape = (batch, m_dim1, m_dim2, m_dim3, 3)
            another one has shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        kwargs = dict(
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        moving = gen_rand_ddf(image_size=self.moving_image_size, **kwargs)
        fixed = gen_rand_ddf(image_size=self.fixed_image_size, **kwargs)
        return moving, fixed

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # broadcast the shared reference grid over the batch and displace it
        return resample(vol=image, loc=grid_ref[None, ...] + params)
269
270
271
def resize_inputs(
    inputs: Dict[str, tf.Tensor], moving_image_size: tuple, fixed_image_size: tuple
) -> Dict[str, tf.Tensor]:
    """
    Resize every image (and label, when present) to its target size.

    Moving tensors are resized to moving_image_size and fixed tensors to
    fixed_image_size; indices are passed through untouched.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    source_moving, source_fixed = inputs["moving_image"], inputs["fixed_image"]
    passthrough_indices = inputs["indices"]

    resized = {
        "moving_image": resize3d(image=source_moving, size=moving_image_size),
        "fixed_image": resize3d(image=source_fixed, size=fixed_image_size),
    }

    is_labeled = "moving_label" in inputs
    if is_labeled:
        source_moving_label = inputs["moving_label"]
        source_fixed_label = inputs["fixed_label"]
        resized["moving_label"] = resize3d(
            image=source_moving_label, size=moving_image_size
        )
        resized["fixed_label"] = resize3d(
            image=source_fixed_label, size=fixed_image_size
        )

    resized["indices"] = passthrough_indices
    return resized
323
324
325
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: (int, None) = None
) -> tf.Tensor:
    """
    Function that generates a random 3D transformation parameters for a batch of data.

    for 3D coordinates, affine transformation is

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    where each * represents a degree of freedom,
    so there are in total 12 degrees of freedom
    the equation can be denoted as

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 4)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 4)

    the equation can be simplified to

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    so that

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 3)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 3)

    Given original and transformed coordinates,
    we can calculate the transformation matrix using

        x = np.linalg.lstsq(a, b)

    such that

        a x = b

    In our case,

    - a = old
    - b = new
    - x = T

    To generate random transformation,
    we choose to add random perturbation to corner coordinates as follows:
    for corner of coordinates (x, y, z), the noise is

        -(x, y, z) .* (r1, r2, r3)

    where ri is a random number between (0, scale).
    So

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    Thus, we can directly sample between 1-scale and 1 instead

    We choose to calculate the transformation based on
    four corners in a cube centered at (0, 0, 0).
    A cube is shown as below, where

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    :param batch_size: int
    :param scale: a float number between 0 and 1
    :param seed: control the randomness. When given, a private
        ``np.random.RandomState(seed)`` is used, which produces the same values
        as seeding the global RNG but leaves the global NumPy RNG untouched.
        When None, the global NumPy RNG is used as-is (not reseeded), so any
        user-level ``np.random.seed`` call is honoured.
    :return: shape = (batch, 4, 3)
    """

    assert 0 <= scale <= 1
    # Avoid np.random.seed(): it mutates global state, and with seed=None it
    # reseeds from OS entropy, silently discarding the caller's own seeding.
    rng = np.random if seed is None else np.random.RandomState(seed)
    noise = rng.uniform(1 - scale, 1, [batch_size, 4, 3])  # shape = (batch, 4, 3)

    # old represents four corners of a cube
    # corresponding to the corner C G D A as shown above
    old = np.tile(
        [[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, -1, 1], [1, -1, -1, 1]]],
        [batch_size, 1, 1],
    )  # shape = (batch, 4, 4)
    new = old[:, :, :3] * noise  # shape = (batch, 4, 3)

    theta = np.array(
        [np.linalg.lstsq(old[k], new[k], rcond=-1)[0] for k in range(batch_size)]
    )  # shape = (batch, 4, 3)

    return tf.cast(theta, dtype=tf.float32)
450
451
452
def gen_rand_ddf(
    batch_size: int,
    image_size: tuple,
    field_strength: (tuple, list),
    low_res_size: (tuple, list),
    seed: (int, None) = None,
) -> tf.Tensor:
    """
    Function that generates a random 3D DDF for a batch of data.

    A low resolution field of shape
    (batch, low_res_size[0], low_res_size[1], low_res_size[2], 3) is sampled as
    standard-normal noise multiplied by a per-sample, per-axis scale drawn from
    U[0, field_strength], and is then resized to image_size with ``resize3d``.

    :param batch_size: number of fields to generate
    :param image_size: target spatial size (dim1, dim2, dim3) of each field
    :param field_strength: upper bound of the uniform distribution
        U[0, field_strength] from which the field scale (standard deviation)
        is drawn; a scalar, or presumably a length-3 per-axis bound since it
        broadcasts against the trailing axis of size 3
    :param low_res_size: low_resolution deformation field that will be upsampled to
        the original size in order to get smooth and more realistic fields.
        Only its first three entries are used.
    :param seed: control the randomness. When given, a private
        ``np.random.RandomState(seed)`` is used, which produces the same values
        as seeding the global RNG but leaves the global NumPy RNG untouched.
        When None, the global NumPy RNG is used as-is (not reseeded).
    :return: the resized field, expected shape = (batch, dim1, dim2, dim3, 3)
    """

    # Avoid np.random.seed(): it mutates global state, and with seed=None it
    # reseeds from OS entropy, silently discarding the caller's own seeding.
    rng = np.random if seed is None else np.random.RandomState(seed)
    low_res_strength = rng.uniform(0, field_strength, (batch_size, 1, 1, 1, 3))
    low_res_field = low_res_strength * rng.randn(
        batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3
    )
    high_res_field = resize3d(low_res_field, image_size)
    return high_res_field
478