torchio.transforms.preprocessing.spatial.resample (rating: B)

Complexity

Total Complexity 47

Size/Duplication

Total Lines 435
Duplicated Lines 0 %

Importance

Changes 0
Metric  Value
wmc     47
eloc    249
dl      0
loc     435
rs      8.64
c       0
b       0
f       0

13 Methods

Rating  Name                                         Duplication  Size  Complexity
A       Resample._get_downsampling_factor()          0            22    1
D       Resample.apply_transform()                   0            62    13
A       Resample._smooth()                           0            22    1
A       Resample.get_reference_image()               0            32    1
A       Resample._parse_spacing()                    0            17    5
A       Resample.__init__()                          0            28    1
A       Resample.check_affine_key_presence()         0            10    3
A       Resample._get_resampler()                    0            17    1
A       Resample._get_sigmas()                       0            19    1
D       Resample._set_resampler_reference()          0            63    13
A       Resample._set_resampler_from_spacing()       0            7     1
A       Resample.check_affine()                      0            16    5
A       Resample._set_resampler_from_shape_affine()  0            6     1
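In this report, the wmc value (weighted methods per class) appears to be the sum of the per-method Complexity column: the thirteen values add up to 47, the same number shown as Total Complexity. A short check, with the values copied from the method table (the dictionary is only an illustration, not output of the analysis tool):

```python
# Complexity column of the method table, keyed by method name
complexities = {
    'Resample._get_downsampling_factor()': 1,
    'Resample.apply_transform()': 13,
    'Resample._smooth()': 1,
    'Resample.get_reference_image()': 1,
    'Resample._parse_spacing()': 5,
    'Resample.__init__()': 1,
    'Resample.check_affine_key_presence()': 3,
    'Resample._get_resampler()': 1,
    'Resample._get_sigmas()': 1,
    'Resample._set_resampler_reference()': 13,
    'Resample._set_resampler_from_spacing()': 1,
    'Resample.check_affine()': 5,
    'Resample._set_resampler_from_shape_affine()': 1,
}

# Weighted methods per class: sum of the per-method complexities
wmc = sum(complexities.values())
print(wmc)  # 47
```

The two D-rated methods account for 26 of those 47 points, so they are the natural starting point for any refactoring.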

How to fix: Complexity

Complex classes like Resample (in torchio.transforms.preprocessing.spatial.resample) often do a lot of different things. To break such a class down, identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.
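The prefix heuristic can be tried mechanically on the method names from the table. The helper below, group_by_prefix, is a hypothetical illustration (not part of TorchIO or of the analysis tool): it strips leading underscores, keys each name on its first two underscore-separated words, and keeps only groups with more than one member.

```python
from __future__ import annotations

from collections import defaultdict

# Method names of Resample, as listed in the method table
METHODS = [
    '_get_downsampling_factor',
    'apply_transform',
    '_smooth',
    'get_reference_image',
    '_parse_spacing',
    '__init__',
    'check_affine_key_presence',
    '_get_resampler',
    '_get_sigmas',
    '_set_resampler_reference',
    '_set_resampler_from_spacing',
    'check_affine',
    '_set_resampler_from_shape_affine',
]


def group_by_prefix(names: list[str], words: int = 2) -> dict[str, list[str]]:
    """Group method names that share their first `words` name tokens."""
    groups: dict[str, list[str]] = defaultdict(list)
    for name in names:
        # '_set_resampler_reference' -> ['set', 'resampler', 'reference']
        tokens = name.strip('_').split('_')
        groups['_'.join(tokens[:words])].append(name)
    # Singletons do not suggest a shared component
    return {prefix: members for prefix, members in groups.items() if len(members) > 1}


for prefix, members in group_by_prefix(METHODS).items():
    print(prefix, members)
```

With two words, this surfaces the three _set_resampler_* methods and the check_affine pair; with words=1, the four get methods also form a group.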

Once you have determined the fields and methods that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
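As a hedged illustration of Extract Class, the branching that classifies target in Resample._set_resampler_reference could move into a small value object. ResampleTarget and its parse method are hypothetical names, not TorchIO API. The sketch depends only on NumPy and duck-types TorchIO images (anything with spatial_shape and affine attributes) so that it runs without TorchIO installed.

```python
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass
from numbers import Number
from pathlib import Path
from typing import Any

import numpy as np


@dataclass(frozen=True, eq=False)
class ResampleTarget:
    """Hypothetical extracted class: a validated description of the output space."""

    kind: str  # 'reference', 'spacing' or 'shape_affine'
    reference: Any = None  # str, Path or image-like object
    spacing: tuple[float, float, float] | None = None
    shape: tuple[int, int, int] | None = None
    affine: np.ndarray | None = None

    @classmethod
    def parse(cls, target: Any) -> ResampleTarget:
        # Strings and paths first: a str is also an Iterable
        if isinstance(target, (str, Path)):
            return cls(kind='reference', reference=target)
        # Duck-typed stand-in for isinstance(target, torchio.Image); checked
        # before the Iterable branch because TorchIO images are dict-like
        if hasattr(target, 'spatial_shape') and hasattr(target, 'affine'):
            return cls(kind='reference', reference=target)
        if isinstance(target, Number):
            return cls(kind='spacing', spacing=cls._positive_triplet(3 * (target,)))
        if isinstance(target, Iterable):
            items = list(target)
            if len(items) == 3:
                return cls(kind='spacing', spacing=cls._positive_triplet(items))
            if len(items) == 2:
                shape = tuple(int(n) for n in items[0])
                affine = np.asarray(items[1], dtype=float)
                if len(shape) != 3:
                    raise ValueError(f'Target shape must have 3 values, not {shape}')
                if affine.shape != (4, 4):
                    raise ValueError(f'Target affine must be 4 x 4, not {affine.shape}')
                return cls(kind='shape_affine', shape=shape, affine=affine)
        raise ValueError(f'Target not understood: {target!r}')

    @staticmethod
    def _positive_triplet(values: Iterable[Any]) -> tuple[float, float, float]:
        # Same rule as Resample._parse_spacing: every spacing must be > 0
        spacing = tuple(float(v) for v in values)
        if any(s <= 0 for s in spacing):
            raise ValueError(f'Spacing must be strictly positive, not {spacing}')
        return spacing
```

With this split, the resampler-configuring code could branch on kind alone, moving most of the type checks out of the D-rated method.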

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Sized
from numbers import Number
from pathlib import Path
from typing import Union

import numpy as np
import SimpleITK as sitk
import torch

from ....data.image import Image
from ....data.image import ScalarImage
from ....data.io import get_sitk_metadata_from_ras_affine
from ....data.io import sitk_to_nib
from ....data.subject import Subject
from ....types import TypePath
from ....types import TypeTripletFloat
from ...spatial_transform import SpatialTransform

TypeSpacing = Union[float, tuple[float, float, float]]
TypeTarget = Union[TypeSpacing, str, Path, Image, None]
ONE_MILLIMITER_ISOTROPIC = 1


class Resample(SpatialTransform):
    """Resample image to a different physical space.

    This is a powerful transform that can be used to change the image shape
    or spatial metadata, or to apply a spatial transformation.

    Args:
        target: Argument to define the output space. Can be one of:

            - Output spacing :math:`(s_w, s_h, s_d)`, in mm. If only one value
              :math:`s` is specified, then :math:`s_w = s_h = s_d = s`.

            - Path to an image that will be used as reference.

            - Instance of :class:`~torchio.Image`.

            - Name of an image key in the subject.

            - Tuple ``(spatial_shape, affine)`` defining the output space.

        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        image_interpolation: See :ref:`Interpolation`.
        label_interpolation: See :ref:`Interpolation`.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        antialias: If ``True``, apply Gaussian smoothing before
            downsampling along any dimension that will be downsampled. For example,
            if the input image has spacing (0.5, 0.5, 4) and the target
            spacing is (1, 1, 1), the image will be smoothed along the first two
            dimensions before resampling. Label maps are not smoothed.
            The standard deviations of the Gaussian kernels are computed according to
            the method described in Cardoso et al.,
            `Scale factor point spread function matching: beyond aliasing in image
            resampling
            <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
            MICCAI 2015.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample()                      # resample all images to 1mm isotropic
        >>> transform = tio.Resample(2)                     # resample all images to 2mm isotropic
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. note::
        The ``antialias`` option is recommended when large (e.g. > 2×) downsampling
        factors are expected, particularly for offline (before training) preprocessing,
        when run times are not a concern.

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Antialias off')
        resample = tio.Resample(8, antialias=True)
        t1_resampled_antialias = resample(subject.t1)
        subject.add_image(t1_resampled_antialias, 'Antialias on')
        subject.plot()
    """

    def __init__(
        self,
        target: TypeTarget = ONE_MILLIMITER_ISOTROPIC,
        image_interpolation: str = 'linear',
        label_interpolation: str = 'nearest',
        pre_affine_name: str | None = None,
        scalars_only: bool = False,
        antialias: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )
        self.label_interpolation = self.parse_interpolation(
            label_interpolation,
        )
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.antialias = antialias
        self.args_names = [
            'target',
            'image_interpolation',
            'label_interpolation',
            'pre_affine_name',
            'scalars_only',
            'antialias',
        ]

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> tuple[float, float, float]:
        result: tuple
        if isinstance(spacing, Iterable) and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        if not isinstance(affine_name, str):
            message = f'Affine name argument must be a string, not {type(affine_name)}'
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = f'The affine matrix shape must be (4, 4), not {matrix.shape}'
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            assert self.pre_affine_name is not None  # for mypy
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for image in self.get_images(subject):
            # If the current image is the reference, don't resample it
            if self.target is image:
                continue

            # If the target is the name of an image in the subject and that
            # image is the current one, don't resample it. For any other
            # target, the lookup fails and this check is skipped
            try:
                target_image = subject[self.target]
                if target_image is image:
                    continue
            except (KeyError, TypeError, RuntimeError):
                pass

            # Choose interpolation
            if not isinstance(image, ScalarImage):
                if self.scalars_only:
                    continue
                interpolation = self.label_interpolation
            else:
                interpolation = self.image_interpolation
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.pre_affine_name in image:
                assert self.pre_affine_name is not None  # for mypy
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image.affine = matrix @ image.affine

            floating_sitk = image.as_sitk(force_3d=True)

            resampler = self._get_resampler(
                interpolator,
                floating_sitk,
                subject,
                self.target,
            )
            if self.antialias and isinstance(image, ScalarImage):
                downsampling_factor = self._get_downsampling_factor(
                    floating_sitk,
                    resampler,
                )
                sigmas = self._get_sigmas(
                    downsampling_factor,
                    floating_sitk.GetSpacing(),
                )
                floating_sitk = self._smooth(floating_sitk, sigmas)
            resampled = resampler.Execute(floating_sitk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    @staticmethod
    def _smooth(
        image: sitk.Image,
        sigmas: np.ndarray,
        epsilon: float = 1e-9,
    ) -> sitk.Image:
        """Smooth the image with a Gaussian kernel.

        Args:
            image: Image to be smoothed.
            sigmas: Standard deviations of the Gaussian kernel for each
                dimension. If a value is NaN, no smoothing is applied in that
                dimension.
            epsilon: Small value to replace NaN values in sigmas, to avoid
                division-by-zero errors.
        """

        sigmas[np.isnan(sigmas)] = epsilon  # no smoothing in that dimension
        gaussian = sitk.SmoothingRecursiveGaussianImageFilter()
        gaussian.SetSigma(sigmas.tolist())
        smoothed = gaussian.Execute(image)
        return smoothed

    @staticmethod
    def _get_downsampling_factor(
        floating: sitk.Image,
        resampler: sitk.ResampleImageFilter,
    ) -> np.ndarray:
        """Get the downsampling factor for each dimension.

        The downsampling factor is the ratio between the output spacing and
        the input spacing. If the output spacing is smaller than or equal to
        the input spacing, the factor is set to NaN, meaning that the image is
        not downsampled (and therefore not smoothed) in that dimension.

        Args:
            floating: The input image to be resampled.
            resampler: The resampler that will be used to resample the image.
        """
        input_spacing = np.array(floating.GetSpacing())
        output_spacing = np.array(resampler.GetOutputSpacing())
        factors = output_spacing / input_spacing
        no_downsampling = factors <= 1
        factors[no_downsampling] = np.nan
        return factors

    def _get_resampler(
        self,
        interpolator: int,
        floating: sitk.Image,
        subject: Subject,
        target: TypeTarget,
    ) -> sitk.ResampleImageFilter:
        """Instantiate a SimpleITK resampler."""
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(interpolator)
        self._set_resampler_reference(
            resampler,
            target,  # type: ignore[arg-type]
            floating,
            subject,
        )
        return resampler

    def _set_resampler_reference(
        self,
        resampler: sitk.ResampleImageFilter,
        target: TypeSpacing | TypePath | Image,
        floating_sitk,
        subject,
    ):
        # Target can be:
        # 1) An instance of torchio.Image
        # 2) An instance of pathlib.Path
        # 3) A string, which could be a path or an image in subject
        # 4) A number or sequence of numbers for spacing
        # 5) A tuple of shape, affine
        # Only case 4 defines the output grid relative to the floating image
        if isinstance(target, (str, Path, Image)):
            if isinstance(target, Image):
                # It's a TorchIO image
                image = target
            elif Path(target).is_file():
                # It's an existing file
                path = target
                image = ScalarImage(path)
            else:  # assume it's the name of an image in the subject
                try:
                    image = subject[target]
                except KeyError as error:
                    message = (
                        f'Image name "{target}" not found in subject.'
                        f' If "{target}" is a path, it does not exist or'
                        ' permission has been denied'
                    )
                    raise ValueError(message) from error
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Iterable) and len(target) == 2:
            assert not isinstance(target, str)  # for mypy
            shape, affine = target
            affine = np.asarray(affine)
            if not (isinstance(shape, Sized) and len(shape) == 3):
                message = (
                    'Target shape must be a sequence of three integers, but'
                    f' "{shape}" was passed'
                )
                raise RuntimeError(message)
            if affine.shape != (4, 4):
                message = (
                    'Target affine must have shape (4, 4) but the following'
                    f' was passed:\n{affine}'
                )
                raise RuntimeError(message)
            self._set_resampler_from_shape_affine(
                resampler,
                shape,
                affine,
            )
        elif isinstance(target, Iterable) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        resampler.SetSize(shape)

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
        floating_sitk: sitk.Image,
        spacing: TypeTripletFloat,
    ) -> sitk.Image:
        old_spacing = np.array(floating_sitk.GetSpacing(), dtype=float)
        new_spacing = np.array(spacing, dtype=float)
        old_size = np.array(floating_sitk.GetSize())
        old_last_index = old_size - 1
        old_last_index_lps = np.array(
            floating_sitk.TransformIndexToPhysicalPoint(old_last_index.tolist()),
            dtype=float,
        )
        old_origin_lps = np.array(floating_sitk.GetOrigin(), dtype=float)
        center_lps = (old_last_index_lps + old_origin_lps) / 2
        # We use floor to avoid extrapolation by keeping the extent of the
        # new image the same or smaller than the original.
        new_size = np.floor(old_size * old_spacing / new_spacing)
        # We keep singleton dimensions to avoid e.g. making 2D images 3D
        new_size[old_size == 1] = 1
        direction = np.asarray(floating_sitk.GetDirection(), dtype=float).reshape(3, 3)
        half_extent = (new_size - 1) / 2 * new_spacing
        new_origin_lps = (center_lps - direction @ half_extent).tolist()
        reference = sitk.Image(
            new_size.astype(int).tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def _get_sigmas(downsampling_factor: np.ndarray, spacing: np.ndarray) -> np.ndarray:
        """Compute the optimal standard deviation of the Gaussian kernel.

        From Cardoso et al., `Scale factor point spread function matching:
        beyond aliasing in image resampling
        <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
        MICCAI 2015. Dimensions with a NaN downsampling factor get a NaN
        standard deviation.

        Args:
            downsampling_factor: Array with the downsampling factor for each
                dimension.
            spacing: Array with the spacing of the input image in mm.
        """
        k = downsampling_factor
        # Equation from top of page 678 of proceedings (4/9 in the PDF)
        variance = (k**2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
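The antialiasing path chains _get_downsampling_factor and _get_sigmas. The NumPy-only sketch below reproduces that arithmetic outside SimpleITK for the example in the antialias docstring (input spacing (0.5, 0.5, 4) mm, target spacing 1 mm isotropic). The function name antialias_sigmas is hypothetical; it mirrors the two methods rather than calling them. The variance (k^2 - 1) * (2 * sqrt(2 * ln 2))^-2 is in voxel units and is scaled by the input spacing to give sigma in mm.

```python
import numpy as np


def antialias_sigmas(input_spacing, output_spacing):
    """Hypothetical mirror of Resample._get_downsampling_factor + _get_sigmas."""
    input_spacing = np.asarray(input_spacing, dtype=float)
    output_spacing = np.asarray(output_spacing, dtype=float)
    # Downsampling factor per axis; axes that are not downsampled become NaN
    k = output_spacing / input_spacing
    k[k <= 1] = np.nan
    # Gaussian variance in voxel units: (k^2 - 1) / (2 * sqrt(2 * ln 2))^2
    variance = (k**2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
    # Scale by the input spacing to express sigma in mm (NaN stays NaN)
    return input_spacing * np.sqrt(variance)


sigmas = antialias_sigmas((0.5, 0.5, 4), (1, 1, 1))
print(sigmas)
```

For this input, k is 2 along the first two axes, giving sigma = 0.5 * sqrt(3 / (8 ln 2)), about 0.37 mm, while the third axis (k = 0.25) yields NaN, which _smooth then replaces with a tiny epsilon so that axis is effectively left unsmoothed.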