1
|
|
|
import warnings |
2
|
|
|
from numbers import Number |
3
|
|
|
from typing import Tuple, Union, Sequence |
4
|
|
|
|
5
|
|
|
import torch |
6
|
|
|
import numpy as np |
7
|
|
|
import SimpleITK as sitk |
8
|
|
|
|
9
|
|
|
from ....utils import to_tuple |
10
|
|
|
from ....data.io import nib_to_sitk |
11
|
|
|
from ....data.subject import Subject |
12
|
|
|
from ....data.image import ScalarImage |
13
|
|
|
from ....typing import TypeTripletInt, TypeTripletFloat |
14
|
|
|
from ... import SpatialTransform |
15
|
|
|
from .. import RandomTransform |
16
|
|
|
|
17
|
|
|
|
18
|
|
|
# Order of the B-spline basis used to interpolate the displacement field
# (3 means cubic). It is subtracted from the number of control points along
# each axis to obtain the mesh size passed to SimpleITK.
SPLINE_ORDER = 3
19
|
|
|
|
20
|
|
|
|
21
|
|
|
class RandomElasticDeformation(RandomTransform, SpatialTransform):
    r"""Apply dense random elastic deformation.

    A random displacement is assigned to a coarse grid of control points around
    and inside the image. The displacement at each voxel is interpolated from
    the coarse grid using cubic B-splines.

    The `'Deformable Registration' <https://www.sciencedirect.com/topics/computer-science/deformable-registration>`_
    topic on ScienceDirect contains useful articles explaining interpolation of
    displacement fields using cubic B-splines.

    .. warning:: This transform is slow as it requires expensive computations.
        If your images are large you might want to use
        :class:`~torchio.transforms.RandomAffine` instead.

    Args:
        num_control_points: Number of control points along each dimension of
            the coarse grid :math:`(n_x, n_y, n_z)`.
            If a single value :math:`n` is passed,
            then :math:`n_x = n_y = n_z = n`.
            Smaller numbers generate smoother deformations.
            The minimum number of control points is ``4`` as this transform
            uses cubic B-splines to interpolate displacement.
        max_displacement: Maximum displacement along each dimension at each
            control point :math:`(D_x, D_y, D_z)`.
            The displacement along dimension :math:`i` at each control point is
            :math:`d_i \sim \mathcal{U}(-D_i, D_i)`.
            If a single value :math:`D` is passed,
            then :math:`D_x = D_y = D_z = D`.
            Note that the total maximum displacement would actually be
            :math:`D_{max} = \sqrt{D_x^2 + D_y^2 + D_z^2}`.
        locked_borders: If ``0``, all displacement vectors are kept.
            If ``1``, displacement of control points at the
            border of the coarse grid will be set to ``0``.
            If ``2``, displacement of control points at the border of the image
            (red dots in the image below) will also be set to ``0``.
        image_interpolation: See :ref:`Interpolation`.
            Note that this is the interpolation used to compute voxel
            intensities when resampling using the dense displacement field.
            The value of the dense displacement at each voxel is always
            interpolated with cubic B-splines from the values at the control
            points of the coarse grid.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    `This gist <https://gist.github.com/fepegar/b723d15de620cd2a3a4dbd71e491b59d>`_
    can also be used to better understand the meaning of the parameters.

    This is an example from the
    `3D Slicer registration FAQ <https://www.slicer.org/wiki/Documentation/4.10/FAQ/Registration#What.27s_the_BSpline_Grid_Size.3F>`_.

    .. image:: https://www.slicer.org/w/img_auth.php/6/6f/RegLib_BSplineGridModel.png
        :alt: B-spline example from 3D Slicer documentation

    To generate a similar grid of control points with TorchIO,
    the transform can be instantiated as follows::

        >>> from torchio import RandomElasticDeformation
        >>> transform = RandomElasticDeformation(
        ...     num_control_points=(7, 7, 7),  # or just 7
        ...     locked_borders=2,
        ... )

    Note that control points outside the image bounds are not shown in the
    example image (they would also be red as we set :attr:`locked_borders`
    to ``2``).

    .. warning:: Image folding may occur if the maximum displacement is larger
        than half the coarse grid spacing. The grid spacing can be computed
        using the image bounds in physical space [#]_ and the number of control
        points::

            >>> import numpy as np
            >>> import torchio as tio
            >>> image = tio.datasets.Slicer().MRHead.as_sitk()
            >>> image.GetSize()  # in voxels
            (256, 256, 130)
            >>> image.GetSpacing()  # in mm
            (1.0, 1.0, 1.2999954223632812)
            >>> bounds = np.array(image.GetSize()) * np.array(image.GetSpacing())
            >>> bounds  # mm
            array([256.        , 256.        , 168.99940491])
            >>> num_control_points = np.array((7, 7, 6))
            >>> grid_spacing = bounds / (num_control_points - 2)
            >>> grid_spacing
            array([51.2       , 51.2       , 42.24985123])
            >>> potential_folding = grid_spacing / 2
            >>> potential_folding  # mm
            array([25.6       , 25.6       , 21.12492561])

        Using a :attr:`max_displacement` larger than the computed
        :attr:`potential_folding` will issue a :class:`RuntimeWarning`.

    .. [#] Technically, :math:`2 \epsilon` should be added to the
        image bounds, where :math:`\epsilon = 2^{-3}` `according to ITK
        source code <https://github.com/InsightSoftwareConsortium/ITK/blob/633f84548311600845d54ab2463d3412194690a8/Modules/Core/Transform/include/itkBSplineTransformInitializer.hxx#L116-L138>`_.
    """  # noqa: E501

    def __init__(
            self,
            num_control_points: Union[int, Tuple[int, int, int]] = 7,
            max_displacement: Union[float, Tuple[float, float, float]] = 7.5,
            locked_borders: int = 2,
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        self._bspline_transformation = None
        # Scalars are expanded to one value per spatial dimension before
        # being validated.
        self.num_control_points = to_tuple(num_control_points, length=3)
        _parse_num_control_points(self.num_control_points)
        self.max_displacement = to_tuple(max_displacement, length=3)
        _parse_max_displacement(self.max_displacement)
        self.num_locked_borders = locked_borders
        if locked_borders not in (0, 1, 2):
            raise ValueError('locked_borders must be 0, 1, or 2')
        # Locking two layers at both ends of an axis that only has four
        # control points zeroes the whole grid.
        if locked_borders == 2 and 4 in self.num_control_points:
            message = (
                'Setting locked_borders to 2 and using less than 5 control'
                ' points results in an identity transform. Lock fewer borders'
                ' or use more control points.'
            )
            raise ValueError(message)
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)

    @staticmethod
    def get_params(
            num_control_points: TypeTripletInt,
            max_displacement: Tuple[float, float, float],
            num_locked_borders: int,
            ) -> np.ndarray:
        """Sample a random coarse displacement field.

        Args:
            num_control_points: Number of control points along each of the
                three grid axes.
            max_displacement: Maximum absolute displacement for each of the
                three displacement components.
            num_locked_borders: Number of outer layers of control points, at
                both ends of every grid axis, whose displacement is set to 0.

        Returns:
            Array of shape ``(*num_control_points, 3)`` in which the last axis
            holds the displacement components.
        """
        grid_shape = num_control_points
        num_dimensions = 3
        coarse_field = torch.rand(*grid_shape, num_dimensions)  # [0, 1)
        coarse_field -= 0.5  # [-0.5, 0.5)
        coarse_field *= 2  # [-1, 1)
        for dimension in range(num_dimensions):
            # [-max_displacement, max_displacement)
            coarse_field[..., dimension] *= max_displacement[dimension]

        # Set displacement to 0 at the borders of all three grid axes
        for i in range(num_locked_borders):
            coarse_field[i, :] = 0
            coarse_field[-1 - i, :] = 0
            coarse_field[:, i] = 0
            coarse_field[:, -1 - i] = 0
            coarse_field[:, :, i] = 0
            coarse_field[:, :, -1 - i] = 0

        return coarse_field.numpy()

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample a displacement field and apply it to the subject.

        The sampled control points are handed to a deterministic
        :class:`ElasticDeformation`, which performs the resampling.
        """
        subject.check_consistent_spatial_shape()
        control_points = self.get_params(
            self.num_control_points,
            self.max_displacement,
            self.num_locked_borders,
        )

        arguments = {
            'control_points': control_points,
            'max_displacement': self.max_displacement,
            'image_interpolation': self.image_interpolation,
        }

        transform = ElasticDeformation(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        return transformed
187
|
|
|
|
188
|
|
|
|
189
|
|
|
class ElasticDeformation(SpatialTransform):
    r"""Apply dense elastic deformation.

    Args:
        control_points: Array with the displacement vectors at the control
            points of the coarse grid. The leading axes are the grid axes and
            the last axis holds the displacement components.
        max_displacement: Maximum displacement along each dimension at each
            control point. The transform is skipped if all values are zero,
            and the values are compared against the coarse grid spacing to
            warn about possible folding.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
            self,
            control_points: np.ndarray,
            max_displacement: TypeTripletFloat,
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        self.control_points = control_points
        self.max_displacement = max_displacement
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)
        self.invert_transform = False
        self.args_names = (
            'control_points',
            'image_interpolation',
            'max_displacement',
        )

    def get_bspline_transform(
            self,
            image: sitk.Image,
            ) -> sitk.BSplineTransform:
        """Build the B-spline transform defined by the control points.

        Args:
            image: Image whose geometry is used to initialize the transform.

        Returns:
            B-spline transform whose parameters are the (possibly negated)
            control point displacements.
        """
        # Work on a copy so that the stored control points are never modified
        control_points = self.control_points.copy()
        if self.invert_transform:
            control_points *= -1
        is_2d = image.GetSize()[2] == 1
        if is_2d:
            control_points[..., -1] = 0  # no displacement in IS axis
        num_control_points = control_points.shape[:-1]
        mesh_shape = [n - SPLINE_ORDER for n in num_control_points]
        bspline_transform = sitk.BSplineTransformInitializer(image, mesh_shape)
        # Fortran order makes the first grid axis vary fastest and the
        # displacement component slowest in the flat parameter list
        parameters = control_points.flatten(order='F').tolist()
        bspline_transform.SetParameters(parameters)
        return bspline_transform

    @staticmethod
    def parse_free_form_transform(
            transform: sitk.Transform,
            max_displacement: TypeTripletFloat,
            ) -> None:
        """Issue a warning if possible folding is detected.

        Folding is considered possible along a dimension when the maximum
        displacement is larger than half the spacing of the coarse grid.

        Args:
            transform: B-spline transform providing the coefficient images
                from which the coarse grid spacing is read.
            max_displacement: Maximum displacement along each dimension.
        """
        coefficient_images = transform.GetCoefficientImages()
        grid_spacing = coefficient_images[0].GetSpacing()
        conflicts = np.array(max_displacement) > np.array(grid_spacing) / 2
        if np.any(conflicts):
            where, = np.where(conflicts)
            message = (
                'The maximum displacement is larger than half the coarse grid'
                f' spacing for dimensions: {where.tolist()}, so folding may'
                ' occur. Choose fewer control points or a smaller'
                ' maximum displacement'
            )
            warnings.warn(message, RuntimeWarning)

    def apply_transform(self, subject: Subject) -> Subject:
        """Deform every image of the subject selected by ``get_images``.

        The subject is returned untouched when all maximum displacements are
        zero. Images that are not :class:`ScalarImage` instances are always
        resampled with nearest neighbor interpolation.
        """
        no_displacement = not any(self.max_displacement)
        if no_displacement:
            return subject
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if not isinstance(image, ScalarImage):
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation
            transformed = self.apply_bspline_transform(
                image.data,
                image.affine,
                interpolation,
            )
            image.set_data(transformed)
        return subject

    def apply_bspline_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            interpolation: str,
            ) -> torch.Tensor:
        """Resample each component of a 4D tensor with the B-spline transform.

        Args:
            tensor: 4D tensor; iterating over its first axis yields the
                components that are resampled independently.
            affine: Affine matrix passed to ``nib_to_sitk`` together with
                each component.
            interpolation: Name of the interpolation used to compute the
                intensities of the resampled image.

        Returns:
            Concatenation of the resampled components.
        """
        assert tensor.dim() == 4
        results = []
        for component in tensor:
            # Restore the leading axis dropped by the iteration
            image = nib_to_sitk(component[np.newaxis], affine, force_3d=True)
            bspline_transform = self.get_bspline_transform(image)
            self.parse_free_form_transform(
                bspline_transform,
                self.max_displacement,
            )
            interpolator = self.get_sitk_interpolator(interpolation)
            resampler = sitk.ResampleImageFilter()
            # The image is resampled onto its own grid
            resampler.SetReferenceImage(image)
            resampler.SetTransform(bspline_transform)
            resampler.SetInterpolator(interpolator)
            # Fill voxels mapped from outside the image with its minimum
            resampler.SetDefaultPixelValue(component.min().item())
            resampler.SetOutputPixelType(sitk.sitkFloat32)
            resampled = resampler.Execute(image)
            result, _ = self.sitk_to_nib(resampled)
            results.append(torch.as_tensor(result))
        tensor = torch.cat(results)
        return tensor
301
|
|
|
|
302
|
|
|
|
303
|
|
|
def _parse_num_control_points( |
304
|
|
|
num_control_points: TypeTripletInt, |
305
|
|
|
) -> None: |
306
|
|
|
for axis, number in enumerate(num_control_points): |
307
|
|
|
if not isinstance(number, int) or number < 4: |
308
|
|
|
message = ( |
309
|
|
|
f'The number of control points for axis {axis} must be' |
310
|
|
|
f' an integer greater than 3, not {number}' |
311
|
|
|
) |
312
|
|
|
raise ValueError(message) |
313
|
|
|
|
314
|
|
|
|
315
|
|
|
def _parse_max_displacement( |
316
|
|
|
max_displacement: Tuple[float, float, float], |
317
|
|
|
) -> None: |
318
|
|
|
for axis, number in enumerate(max_displacement): |
319
|
|
|
if not isinstance(number, Number) or number < 0: |
320
|
|
|
message = ( |
321
|
|
|
'The maximum displacement at each control point' |
322
|
|
|
f' for axis {axis} must be' |
323
|
|
|
f' a number greater or equal to 0, not {number}' |
324
|
|
|
) |
325
|
|
|
raise ValueError(message) |
326
|
|
|
|