torchio.transforms.augmentation.spatial.random_elastic_deformation   A
last analyzed

Complexity

Total Complexity 28

Size/Duplication

Total Lines 326
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 28
eloc 169
dl 0
loc 326
rs 10
c 0
b 0
f 0

2 Functions

Rating   Name   Duplication   Size   Complexity  
A _parse_max_displacement() 0 11 4
A _parse_num_control_points() 0 10 4

8 Methods

Rating   Name   Duplication   Size   Complexity  
A ElasticDeformation.parse_free_form_transform() 0 18 2
A RandomElasticDeformation.__init__() 0 26 4
A RandomElasticDeformation.get_params() 0 23 3
A ElasticDeformation.apply_bspline_transform() 0 28 2
A ElasticDeformation.apply_transform() 0 17 4
A ElasticDeformation.get_bspline_transform() 0 16 3
A RandomElasticDeformation.apply_transform() 0 17 1
A ElasticDeformation.__init__() 0 17 1
1
import warnings
2
from numbers import Number
3
from typing import Tuple, Union, Sequence
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....utils import to_tuple
10
from ....data.io import nib_to_sitk
11
from ....data.subject import Subject
12
from ....data.image import ScalarImage
13
from ....typing import TypeTripletInt, TypeTripletFloat
14
from ... import SpatialTransform
15
from .. import RandomTransform
16
17
18
# Order of the B-splines used to interpolate the coarse displacement grid
# (3 means cubic). Along each axis, a grid with M mesh elements has
# M + SPLINE_ORDER control points.
SPLINE_ORDER: int = 3
19
20
21
class RandomElasticDeformation(RandomTransform, SpatialTransform):
    r"""Apply dense random elastic deformation.

    A random displacement is assigned to a coarse grid of control points around
    and inside the image. The displacement at each voxel is interpolated from
    the coarse grid using cubic B-splines.

    The `'Deformable Registration' <https://www.sciencedirect.com/topics/computer-science/deformable-registration>`_
    topic on ScienceDirect contains useful articles explaining interpolation of
    displacement fields using cubic B-splines.

    .. warning:: This transform is slow as it requires expensive computations.
        If your images are large you might want to use
        :class:`~torchio.transforms.RandomAffine` instead.

    Args:
        num_control_points: Number of control points along each dimension of
            the coarse grid :math:`(n_x, n_y, n_z)`.
            If a single value :math:`n` is passed,
            then :math:`n_x = n_y = n_z = n`.
            Smaller numbers generate smoother deformations.
            The minimum number of control points is ``4`` as this transform
            uses cubic B-splines to interpolate displacement.
        max_displacement: Maximum displacement along each dimension at each
            control point :math:`(D_x, D_y, D_z)`.
            The displacement along dimension :math:`i` at each control point is
            :math:`d_i \sim \mathcal{U}(-D_i, D_i)`.
            If a single value :math:`D` is passed,
            then :math:`D_x = D_y = D_z = D`.
            Note that the total maximum displacement would actually be
            :math:`D_{max} = \sqrt{D_x^2 + D_y^2 + D_z^2}`.
        locked_borders: If ``0``, all displacement vectors are kept.
            If ``1``, displacement of control points at the
            border of the coarse grid will be set to ``0``.
            If ``2``, displacement of control points at the border of the image
            (red dots in the image below) will also be set to ``0``.
            Borders are locked along all three axes of the coarse grid.
        image_interpolation: See :ref:`Interpolation`.
            Note that this is the interpolation used to compute voxel
            intensities when resampling using the dense displacement field.
            The value of the dense displacement at each voxel is always
            interpolated with cubic B-splines from the values at the control
            points of the coarse grid.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    `This gist <https://gist.github.com/fepegar/b723d15de620cd2a3a4dbd71e491b59d>`_
    can also be used to better understand the meaning of the parameters.

    This is an example from the
    `3D Slicer registration FAQ <https://www.slicer.org/wiki/Documentation/4.10/FAQ/Registration#What.27s_the_BSpline_Grid_Size.3F>`_.

    .. image:: https://www.slicer.org/w/img_auth.php/6/6f/RegLib_BSplineGridModel.png
        :alt: B-spline example from 3D Slicer documentation

    To generate a similar grid of control points with TorchIO,
    the transform can be instantiated as follows::

        >>> from torchio import RandomElasticDeformation
        >>> transform = RandomElasticDeformation(
        ...     num_control_points=(7, 7, 7),  # or just 7
        ...     locked_borders=2,
        ... )

    Note that control points outside the image bounds are not shown in the
    example image (they would also be red as we set :attr:`locked_borders`
    to ``2``).

    .. warning:: Image folding may occur if the maximum displacement is larger
        than half the coarse grid spacing. The grid spacing can be computed
        using the image bounds in physical space [#]_ and the number of control
        points. As cubic B-splines are used, an axis with :math:`n` control
        points is divided into :math:`n - 3` grid intervals::

            >>> import numpy as np
            >>> import torchio as tio
            >>> image = tio.datasets.Slicer().MRHead.as_sitk()
            >>> image.GetSize()  # in voxels
            (256, 256, 130)
            >>> image.GetSpacing()  # in mm
            (1.0, 1.0, 1.2999954223632812)
            >>> bounds = np.array(image.GetSize()) * np.array(image.GetSpacing())
            >>> bounds  # mm
            array([256.        , 256.        , 168.99940491])
            >>> num_control_points = np.array((7, 7, 6))
            >>> grid_spacing = bounds / (num_control_points - 3)
            >>> grid_spacing
            array([64.        , 64.        , 56.33313497])
            >>> potential_folding = grid_spacing / 2
            >>> potential_folding  # mm
            array([32.        , 32.        , 28.16656748])

        Using a :attr:`max_displacement` larger than the computed
        :attr:`potential_folding` will issue a :class:`RuntimeWarning`.

        .. [#] Technically, :math:`2 \epsilon` should be added to the
            image bounds, where :math:`\epsilon = 2^{-3}` `according to ITK
            source code <https://github.com/InsightSoftwareConsortium/ITK/blob/633f84548311600845d54ab2463d3412194690a8/Modules/Core/Transform/include/itkBSplineTransformInitializer.hxx#L116-L138>`_.
    """  # noqa: E501

    def __init__(
            self,
            num_control_points: Union[int, Tuple[int, int, int]] = 7,
            max_displacement: Union[float, Tuple[float, float, float]] = 7.5,
            locked_borders: int = 2,
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        self._bspline_transformation = None
        self.num_control_points = to_tuple(num_control_points, length=3)
        _parse_num_control_points(self.num_control_points)
        self.max_displacement = to_tuple(max_displacement, length=3)
        _parse_max_displacement(self.max_displacement)
        self.num_locked_borders = locked_borders
        if locked_borders not in (0, 1, 2):
            raise ValueError('locked_borders must be 0, 1, or 2')
        # With 4 control points along an axis, locking 2 layers on each side
        # zeroes every control point along that axis, so every displacement
        # in the grid would be 0
        if locked_borders == 2 and 4 in self.num_control_points:
            message = (
                'Setting locked_borders to 2 and using less than 5 control'
                ' points results in an identity transform. Lock fewer borders'
                ' or use more control points.'
            )
            raise ValueError(message)
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)

    @staticmethod
    def get_params(
            num_control_points: TypeTripletInt,
            max_displacement: Tuple[float, float, float],
            num_locked_borders: int,
            ) -> np.ndarray:
        """Sample random displacements at the control points.

        Args:
            num_control_points: Number of control points along each axis of
                the coarse grid.
            max_displacement: Maximum absolute displacement along each
                dimension.
            num_locked_borders: Number of outermost layers of control points,
                on both sides of every axis, whose displacement is set to 0.

        Returns:
            Array of shape ``(n_x, n_y, n_z, 3)`` whose last axis holds the
            displacement components. Component ``i`` is drawn uniformly from
            ``[-max_displacement[i], max_displacement[i])``.
        """
        grid_shape = num_control_points
        num_dimensions = 3
        coarse_field = torch.rand(*grid_shape, num_dimensions)  # [0, 1)
        coarse_field -= 0.5  # [-0.5, 0.5)
        coarse_field *= 2  # [-1, 1)
        for dimension in range(num_dimensions):
            # [-max_displacement, max_displacement)
            coarse_field[..., dimension] *= max_displacement[dimension]

        # Set displacement to 0 at the borders of the coarse grid, along all
        # three spatial axes (the last axis holds the vector components)
        for i in range(num_locked_borders):
            coarse_field[i, :, :] = 0
            coarse_field[-1 - i, :, :] = 0
            coarse_field[:, i, :] = 0
            coarse_field[:, -1 - i, :] = 0
            coarse_field[:, :, i] = 0
            coarse_field[:, :, -1 - i] = 0

        return coarse_field.numpy()

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample control point displacements and deform ``subject``.

        The actual resampling is delegated to :class:`ElasticDeformation`.
        """
        subject.check_consistent_spatial_shape()
        control_points = self.get_params(
            self.num_control_points,
            self.max_displacement,
            self.num_locked_borders,
        )

        arguments = {
            'control_points': control_points,
            'max_displacement': self.max_displacement,
            'image_interpolation': self.image_interpolation,
        }

        transform = ElasticDeformation(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        return transformed
187
188
189
class ElasticDeformation(SpatialTransform):
    r"""Apply dense elastic deformation.

    The dense displacement field is interpolated with cubic B-splines from the
    displacement vectors defined at the control points of a coarse grid.

    Args:
        control_points: Displacement vector at each control point of the
            coarse grid. The last axis holds the three displacement
            components, so the number of control points along each axis is
            ``control_points.shape[:-1]``. This is the kind of array returned
            by :meth:`RandomElasticDeformation.get_params`.
        max_displacement: Maximum displacement :math:`(D_x, D_y, D_z)` along
            each dimension. If all values are zero, the transform returns the
            subject unchanged. The values are also compared with half the
            coarse grid spacing to warn about possible folding.
        image_interpolation: See :ref:`Interpolation`. Images that are not
            instances of :class:`~torchio.ScalarImage` are always resampled
            with ``'nearest'`` interpolation.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
            self,
            control_points: np.ndarray,
            max_displacement: TypeTripletFloat,
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        self.control_points = control_points
        self.max_displacement = max_displacement
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)
        self.invert_transform = False
        self.args_names = (
            'control_points',
            'image_interpolation',
            'max_displacement',
        )

    def get_bspline_transform(
            self,
            image: sitk.Image,
            ) -> sitk.BSplineTransform:
        """Build a B-spline transform over the domain of ``image``.

        If :attr:`invert_transform` is ``True``, the control point
        displacements are negated. Negating a displacement field only
        approximates the inverse deformation. For images with a single slice
        along the third axis (2D images), the displacement along that axis is
        set to 0.
        """
        # Copy so that the stored control points are never modified in place
        control_points = self.control_points.copy()
        if self.invert_transform:
            control_points *= -1
        is_2d = image.GetSize()[2] == 1
        if is_2d:
            control_points[..., -1] = 0  # no displacement in IS axis
        num_control_points = control_points.shape[:-1]
        # A grid with M mesh elements per axis has M + SPLINE_ORDER control
        # points along that axis
        mesh_shape = [n - SPLINE_ORDER for n in num_control_points]
        bspline_transform = sitk.BSplineTransformInitializer(image, mesh_shape)
        # Fortran order makes the first grid index vary fastest and the
        # component axis slowest: all x components, then all y, then all z
        parameters = control_points.flatten(order='F').tolist()
        bspline_transform.SetParameters(parameters)
        return bspline_transform

    @staticmethod
    def parse_free_form_transform(
            transform: sitk.Transform,
            max_displacement: Sequence[float],
            ) -> None:
        """Issue a warning if possible folding is detected.

        Folding may occur along a dimension when the maximum displacement is
        larger than half the spacing of the coarse grid along it.

        Args:
            transform: B-spline transform, as built by
                :meth:`get_bspline_transform`.
            max_displacement: Maximum displacement along each dimension.

        Warns:
            RuntimeWarning: If folding may occur along any dimension.
        """
        coefficient_images = transform.GetCoefficientImages()
        grid_spacing = coefficient_images[0].GetSpacing()
        conflicts = np.array(max_displacement) > np.array(grid_spacing) / 2
        if np.any(conflicts):
            where, = np.where(conflicts)
            message = (
                'The maximum displacement is larger than half the coarse grid'
                f' spacing for dimensions: {where.tolist()}, so folding may'
                ' occur. Choose fewer control points or a smaller'
                ' maximum displacement'
            )
            warnings.warn(message, RuntimeWarning)

    def apply_transform(self, subject: Subject) -> Subject:
        """Deform every image of ``subject`` selected by this transform.

        Returns the subject unchanged when all maximum displacements are 0.
        """
        no_displacement = not any(self.max_displacement)
        if no_displacement:
            return subject
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if not isinstance(image, ScalarImage):
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation
            transformed = self.apply_bspline_transform(
                image.data,
                image.affine,
                interpolation,
            )
            image.set_data(transformed)
        return subject

    def apply_bspline_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            interpolation: str,
            ) -> torch.Tensor:
        """Resample each channel of a 4D tensor with the B-spline transform.

        Args:
            tensor: 4D image data; its first dimension is iterated over and
                each slice along it is resampled independently.
            affine: Affine matrix passed to ``nib_to_sitk`` with each slice.
            interpolation: Name of the interpolation used for resampling.

        Returns:
            The resampled slices concatenated along the first dimension. The
            resampler output pixel type is set to ``sitkFloat32``.
        """
        assert tensor.dim() == 4
        # The interpolator depends only on its name, not on the channel
        interpolator = self.get_sitk_interpolator(interpolation)
        results = []
        for component in tensor:
            image = nib_to_sitk(component[np.newaxis], affine, force_3d=True)
            floating = reference = image
            bspline_transform = self.get_bspline_transform(image)
            max_displacement = self.max_displacement
            if image.GetSize()[2] == 1:
                # get_bspline_transform zeroes the displacement along the
                # third axis of 2D images, so no folding can occur along it
                max_displacement = (*max_displacement[:2], 0)
            self.parse_free_form_transform(
                bspline_transform,
                max_displacement,
            )
            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(reference)
            resampler.SetTransform(bspline_transform)
            resampler.SetInterpolator(interpolator)
            # Voxels mapped from outside the image get the channel minimum
            resampler.SetDefaultPixelValue(component.min().item())
            resampler.SetOutputPixelType(sitk.sitkFloat32)
            resampled = resampler.Execute(floating)
            result, _ = self.sitk_to_nib(resampled)
            results.append(torch.as_tensor(result))
        tensor = torch.cat(results)
        return tensor
301
302
303
def _parse_num_control_points(
304
        num_control_points: TypeTripletInt,
305
        ) -> None:
306
    for axis, number in enumerate(num_control_points):
307
        if not isinstance(number, int) or number < 4:
308
            message = (
309
                f'The number of control points for axis {axis} must be'
310
                f' an integer greater than 3, not {number}'
311
            )
312
            raise ValueError(message)
313
314
315
def _parse_max_displacement(
316
        max_displacement: Tuple[float, float, float],
317
        ) -> None:
318
    for axis, number in enumerate(max_displacement):
319
        if not isinstance(number, Number) or number < 0:
320
            message = (
321
                'The maximum displacement at each control point'
322
                f' for axis {axis} must be'
323
                f' a number greater or equal to 0, not {number}'
324
            )
325
            raise ValueError(message)
326