Passed
Pull Request — main (#656)
by Yunguan
02:51
created

RandomTransformation3D.get_config()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""
2
Module containing data augmentation techniques.
3
  - 3D Affine/DDF Transforms for moving and fixed images.
4
"""
5
6
from abc import abstractmethod
7
from typing import Dict
8
9
import numpy as np
10
import tensorflow as tf
11
12
from deepreg.model.layer import Resize3d
13
from deepreg.model.layer_util import get_reference_grid, resample, warp_grid
14
from deepreg.registry import REGISTRY
15
16
17
class RandomTransformation3D(tf.keras.layers.Layer):
    """
    An interface for different types of transformation.

    Subclasses supply ``gen_transform_params`` and ``transform``;
    ``call`` uses them to transform the moving/fixed images and,
    when labels are present, the labels with the same parameters.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
    ):
        """
        Abstract class for image transformation.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param name: name of layer
        :param trainable: if this layer is trainable
        """
        super().__init__(trainable=trainable, name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size
        # reference grids are built once and reused for every call
        self.moving_grid_ref = get_reference_grid(grid_size=moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=fixed_image_size)

    @abstractmethod
    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Generates transformation parameters for moving and fixed image.

        :return: two tensors
        :raises NotImplementedError: if the subclass does not override it
        """
        # `@abstractmethod` only blocks instantiation under an ABCMeta
        # metaclass, which this class does not declare, so fail loudly
        # instead of silently returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement gen_transform_params"
        )

    @staticmethod
    @abstractmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: parameters for transformation
        :return: shape = (batch, dim1, dim2, dim3)
        :raises NotImplementedError: if the subclass does not override it
        """
        # same reasoning as gen_transform_params: never return None silently
        raise NotImplementedError(
            "Subclasses of RandomTransformation3D must implement transform"
        )

    def call(self, inputs: Dict[str, tf.Tensor], **kwargs) -> Dict[str, tf.Tensor]:
        """
        Creates random params for the input images and their labels,
        and transforms them using the corresponding reference grids.

        :param inputs: a dict having multiple tensors
            if labeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                moving_label, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_label, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
            else, unlabeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
        :param kwargs: other arguments
        :return: dictionary with the same structure as inputs
        """

        moving_image = inputs["moving_image"]
        fixed_image = inputs["fixed_image"]
        indices = inputs["indices"]

        # one parameter set per side, shared by the image and its label
        moving_params, fixed_params = self.gen_transform_params()

        moving_image = self.transform(moving_image, self.moving_grid_ref, moving_params)
        fixed_image = self.transform(fixed_image, self.fixed_grid_ref, fixed_params)

        if "moving_label" not in inputs:  # unlabeled
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )
        moving_label = inputs["moving_label"]
        fixed_label = inputs["fixed_label"]

        moving_label = self.transform(moving_label, self.moving_grid_ref, moving_params)
        fixed_label = self.transform(fixed_label, self.fixed_grid_ref, fixed_params)

        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["moving_image_size"] = self.moving_image_size
        config["fixed_image_size"] = self.fixed_image_size
        config["batch_size"] = self.batch_size
        return config
121
122
123
@REGISTRY.register_data_augmentation(name="affine")
class RandomAffineTransform3D(RandomTransformation3D):
    """Augment moving and fixed images with independent random affine transforms."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        scale: float = 0.1,
        name: str = "RandomAffineTransform3D",
        **kwargs,
    ):
        """
        Set up the layer.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param scale: a positive float controlling the scale of transformation
        :param name: name of the layer
        :param kwargs: extra arguments forwarded to the parent constructor
        """
        parent_args = dict(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
        )
        super().__init__(**parent_args, **kwargs)
        self.scale = scale

    def get_config(self) -> dict:
        """Return the parent config extended with ``scale``."""
        return {**super().get_config(), "scale": self.scale}

    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Draw random affine parameters for the moving and the fixed images.

        A single draw of twice the batch size is made and then cut in half,
        the first half for moving images and the second for fixed images.

        :return: a tuple of tensors, each has shape = (batch, 4, 3)
        """
        half = self.batch_size
        all_params = gen_rand_affine_transform(batch_size=half * 2, scale=self.scale)
        moving_params = all_params[:half]
        fixed_params = all_params[half:]
        return moving_params, fixed_params

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with the affine parameters and resample.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        warped_grid = warp_grid(grid_ref, params)
        return resample(vol=image, loc=warped_grid)
186
187
188
@REGISTRY.register_data_augmentation(name="ddf")
class RandomDDFTransform3D(RandomTransformation3D):
    """Apply random DDF transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        field_strength: int = 1,
        low_res_size: tuple = (1, 1, 1),
        name: str = "RandomDDFTransform3D",
        **kwargs,
    ):
        """
        Creates a DDF transformation for data augmentation.

        To simulate smooth deformation fields, we interpolate from a low resolution
        field of size low_res_size. The low resolution field is standard normal
        noise multiplied by a scale factor drawn from a uniform variable
        between [0, field_strength].

        :param moving_image_size: tuple
        :param fixed_image_size: tuple
        :param batch_size: int
        :param field_strength: int = 1. It is used as the upper bound for the
            scale factor of the deformation field
        :param low_res_size: tuple = (1, 1, 1). No dimension may exceed the
            corresponding dimension of the moving or the fixed image.
        :param name: name of layer
        :param kwargs: extra arguments
        :raises AssertionError: if any dimension of low_res_size is larger
            than the corresponding moving or fixed image dimension
        """

        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )

        # Compare dimension by dimension: `tuple <= tuple` is lexicographic,
        # so e.g. (1, 100, 100) <= (2, 2, 2) would wrongly pass.
        assert all(
            low <= full for low, full in zip(low_res_size, moving_image_size)
        ), (
            f"low_res_size {tuple(low_res_size)} must not exceed "
            f"moving_image_size {tuple(moving_image_size)} in any dimension"
        )
        assert all(
            low <= full for low, full in zip(low_res_size, fixed_image_size)
        ), (
            f"low_res_size {tuple(low_res_size)} must not exceed "
            f"fixed_image_size {tuple(fixed_image_size)} in any dimension"
        )

        self.field_strength = field_strength
        self.low_res_size = low_res_size

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["field_strength"] = self.field_strength
        config["low_res_size"] = self.low_res_size
        return config

    def gen_transform_params(self) -> (tf.Tensor, tf.Tensor):
        """
        Generates two random ddf fields for moving and fixed images.

        :return: tuple, one has shape = (batch, m_dim1, m_dim2, m_dim3, 3)
            another one has shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        # settings shared by both fields; only the target image size differs
        kwargs = dict(
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        moving = gen_rand_ddf(image_size=self.moving_image_size, **kwargs)
        fixed = gen_rand_ddf(image_size=self.fixed_image_size, **kwargs)
        return moving, fixed

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # add a leading batch axis to the grid so the DDF broadcasts over it
        return resample(vol=image, loc=grid_ref[None, ...] + params)
270
271
272
def resize_inputs(
    inputs: Dict[str, tf.Tensor], moving_image_size: tuple, fixed_image_size: tuple
) -> Dict[str, tf.Tensor]:
    """
    Resize the images (and labels, if any) of one sample to fixed sizes.

    Moving tensors are resized to ``moving_image_size`` and fixed tensors to
    ``fixed_image_size``; ``indices`` is passed through untouched.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    # mandatory entries are looked up before any layer is created
    raw_moving = inputs["moving_image"]
    raw_fixed = inputs["fixed_image"]
    sample_indices = inputs["indices"]

    to_moving_size = Resize3d(shape=moving_image_size)
    to_fixed_size = Resize3d(shape=fixed_image_size)

    resized = {
        "moving_image": to_moving_size(raw_moving),
        "fixed_image": to_fixed_size(raw_fixed),
    }

    if "moving_label" in inputs:  # labeled
        raw_moving_label = inputs["moving_label"]
        raw_fixed_label = inputs["fixed_label"]
        # labels go through the same layer as their image
        resized["moving_label"] = to_moving_size(raw_moving_label)
        resized["fixed_label"] = to_fixed_size(raw_fixed_label)

    resized["indices"] = sample_indices
    return resized
327
328
329
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: (int, None) = None
) -> tf.Tensor:
    """
    Generate random 3D affine transformation parameters for a batch of data.

    In homogeneous 3D coordinates an affine transformation reads

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    Each * is a degree of freedom, 12 in total. Written compactly,

        new = old * T

    with

    - new, the transformed coordinates, of shape (1, 4)
    - old, the original coordinates, of shape (1, 4)
    - T, the transformation matrix, of shape (4, 4)

    Dropping the constant last column gives

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    which is again

        new = old * T

    but now with

    - new, the transformed coordinates, of shape (1, 3)
    - old, the original coordinates, of shape (1, 4)
    - T, the transformation matrix, of shape (4, 3)

    Knowing both the original and the transformed coordinates,
    T is recovered by solving a x = b with

        x = np.linalg.lstsq(a, b)

    where

    - a = old
    - b = new
    - x = T

    The randomness comes from perturbing corner coordinates:
    a corner (x, y, z) receives the noise

        -(x, y, z) .* (r1, r2, r3)

    with each ri a random number between (0, scale), hence

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    which is why the factors are sampled directly between 1-scale and 1.

    The transformation is computed from four corners
    of a cube centered at (0, 0, 0), namely

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    in the cube drawn below.

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    :param batch_size: int
    :param scale: a float number between 0 and 1
    :param seed: control the randomness
    :return: shape = (batch, 4, 3)
    """

    assert 0 <= scale <= 1
    np.random.seed(seed)
    # per-coordinate shrink factors, shape = (batch, 4, 3)
    shrink = np.random.uniform(low=1 - scale, high=1, size=[batch_size, 4, 3])

    # homogeneous coordinates of the four reference corners
    corners = [
        [-1, -1, -1, 1],  # C
        [-1, -1, 1, 1],  # G
        [-1, 1, -1, 1],  # D
        [1, -1, -1, 1],  # A
    ]
    source = np.tile([corners], [batch_size, 1, 1])  # shape = (batch, 4, 4)
    target = source[:, :, :3] * shrink  # shape = (batch, 4, 3)

    # solve source[k] @ T = target[k] independently for every sample
    solutions = []
    for source_k, target_k in zip(source, target):
        solution, *_ = np.linalg.lstsq(source_k, target_k, rcond=-1)
        solutions.append(solution)
    affine = np.array(solutions)  # shape = (batch, 4, 3)

    return tf.cast(affine, dtype=tf.float32)
454
455
456
def gen_rand_ddf(
    batch_size: int,
    image_size: tuple,
    field_strength: (tuple, list),
    low_res_size: (tuple, list),
    seed: (int, None) = None,
) -> tf.Tensor:
    """
    Generate a random 3D DDF for a batch of data.

    Standard normal noise is drawn on a coarse grid, multiplied by a random
    strength and then resized to the image size by a Resize3d layer.

    :param batch_size: number of fields to generate
    :param image_size: target size passed to Resize3d
    :param field_strength: maximum field strength, computed as a U[0,field_strength]
    :param low_res_size: size of the low resolution deformation field that is
        upsampled to the original size in order to get smooth and more
        realistic fields; its first three entries are used.
    :param seed: control the randomness
    :return: the resized field, output of Resize3d(shape=image_size)
    """

    np.random.seed(seed)
    # one strength per sample and displacement component, broadcast over space
    strength = np.random.uniform(
        low=0, high=field_strength, size=(batch_size, 1, 1, 1, 3)
    )
    coarse_shape = (batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3)
    coarse_field = strength * np.random.randn(*coarse_shape)
    upsample = Resize3d(shape=image_size)
    return upsample(coarse_field)
482