Passed · Push to main (959623...57b877) by Fernando · 02:37 queued 01:03 · created

torchio.transforms.preprocessing.spatial.resample (rating B)

Complexity
  Total Complexity    47

Size/Duplication
  Total Lines         432
  Duplicated Lines    0 %

Importance
  Changes             0

Metric  Value
wmc     47
eloc    249
dl      0
loc     432
rs      8.64
c       0
b       0
f       0

13 Methods

Rating  Name                                         Duplication  Size  Complexity
A       Resample._get_downsampling_factor()          0            22    1
D       Resample.apply_transform()                   0            62    13
A       Resample._smooth()                           0            22    1
A       Resample.get_reference_image()               0            32    1
A       Resample._parse_spacing()                    0            17    5
A       Resample.__init__()                          0            28    1
A       Resample.check_affine_key_presence()         0            10    3
A       Resample._get_resampler()                    0            17    1
A       Resample._get_sigmas()                       0            19    1
D       Resample._set_resampler_reference()          0            63    13
A       Resample._set_resampler_from_spacing()       0            7     1
A       Resample.check_affine()                      0            16    5
A       Resample._set_resampler_from_shape_affine()  0            6     1

How to fix: Complexity

Complex classes like Resample (in torchio.transforms.preprocessing.spatial.resample) often do many different things. To break such a class down, identify a cohesive component within it. A common way to find one is to look for fields or methods that share a prefix or suffix. In Resample, for example, _set_resampler_reference(), _set_resampler_from_spacing() and _set_resampler_from_shape_affine() share the _set_resampler_ prefix, and check_affine() and check_affine_key_presence() share the check_affine prefix.

Once you have identified the fields and methods that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
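One possible Extract Class step for Resample is to move the branching on the target argument (currently spread over _parse_spacing() and _set_resampler_reference()) into a class of its own, so that Resample only turns an already-parsed target into a SimpleITK resampler. The sketch below is hypothetical: TargetParser, SpacingTarget, ShapeAffineTarget and ReferenceTarget are illustrative names, not part of TorchIO, and the torchio.Image case is left out so that the example runs with NumPy alone.

```python
from __future__ import annotations

from dataclasses import dataclass
from numbers import Number
from pathlib import Path
from typing import Union

import numpy as np


@dataclass(frozen=True)
class SpacingTarget:
    """Output spacing in mm, one value per axis."""
    spacing: tuple[float, float, float]


@dataclass(frozen=True, eq=False)  # arrays do not support == as a boolean
class ShapeAffineTarget:
    """Explicit output grid: spatial shape plus a 4x4 affine."""
    shape: tuple[int, int, int]
    affine: np.ndarray


@dataclass(frozen=True)
class ReferenceTarget:
    """A path on disk or the name of an image in the subject."""
    reference: Union[str, Path]


ParsedTarget = Union[SpacingTarget, ShapeAffineTarget, ReferenceTarget]


class TargetParser:
    """Hypothetical component extracted from Resample: it only validates
    and normalizes the target argument and knows nothing about SimpleITK."""

    @staticmethod
    def parse(target) -> ParsedTarget:
        if isinstance(target, (str, Path)):
            # File path or image key: resolved later, once a subject exists
            return ReferenceTarget(target)
        if isinstance(target, Number):
            return SpacingTarget(TargetParser._spacing(3 * (target,)))
        if isinstance(target, (tuple, list)) and len(target) == 2:
            shape, affine = target
            affine = np.asarray(affine, dtype=float)
            if len(shape) != 3 or affine.shape != (4, 4):
                raise ValueError(f'Expected (3D shape, 4x4 affine), got {target!r}')
            return ShapeAffineTarget(tuple(int(n) for n in shape), affine)
        if isinstance(target, (tuple, list)) and len(target) == 3:
            return SpacingTarget(TargetParser._spacing(target))
        raise ValueError(f'Target not understood: {target!r}')

    @staticmethod
    def _spacing(values) -> tuple[float, float, float]:
        spacing = tuple(float(v) for v in values)
        if any(v <= 0 for v in spacing):
            raise ValueError(f'Spacing must be strictly positive, not {spacing}')
        return spacing
```

With this split, Resample would keep only the code that builds a resampler from each of the three parsed target kinds, which would move most of the branching behind the complexity of 13 reported for _set_resampler_reference() out of the class.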

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Sized
from numbers import Number
from pathlib import Path
from typing import Union

import numpy as np
import SimpleITK as sitk
import torch

from ....data.image import Image
from ....data.image import ScalarImage
from ....data.io import get_sitk_metadata_from_ras_affine
from ....data.io import sitk_to_nib
from ....data.subject import Subject
from ....types import TypePath
from ....types import TypeTripletFloat
from ...spatial_transform import SpatialTransform

TypeSpacing = Union[float, tuple[float, float, float]]
TypeTarget = Union[TypeSpacing, str, Path, Image, None]
ONE_MILLIMITER_ISOTROPIC = 1


class Resample(SpatialTransform):
    """Resample image to a different physical space.

    This is a powerful transform that can be used to change the image shape
    or spatial metadata, or to apply a spatial transformation.

    Args:
        target: Argument to define the output space. Can be one of:

            - Output spacing :math:`(s_w, s_h, s_d)`, in mm. If only one value
              :math:`s` is specified, then :math:`s_w = s_h = s_d = s`.

            - Path to an image that will be used as reference.

            - Instance of :class:`~torchio.Image`.

            - Name of an image key in the subject.

            - Tuple ``(spatial_shape, affine)`` defining the output space.

        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        image_interpolation: See :ref:`Interpolation`.
        label_interpolation: See :ref:`Interpolation`.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        antialias: If ``True``, apply Gaussian smoothing before
            downsampling along any dimension that will be downsampled. For example,
            if the input image has spacing (0.5, 0.5, 4) and the target
            spacing is (1, 1, 1), the image will be smoothed along the first two
            dimensions before resampling. Label maps are not smoothed.
            The standard deviations of the Gaussian kernels are computed according to
            the method described in Cardoso et al.,
            `Scale factor point spread function matching: beyond aliasing in image
            resampling
            <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
            MICCAI 2015.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample()                      # resample all images to 1 mm isotropic
        >>> transform = tio.Resample(2)                     # resample all images to 2 mm isotropic
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest-neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. note::
        The ``antialias`` option is recommended when large (e.g. > 2×) downsampling
        factors are expected, particularly for offline (before training) preprocessing,
        when run times are not a concern.

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Downsampled')
        subject.plot()
    """

    def __init__(
        self,
        target: TypeTarget = ONE_MILLIMITER_ISOTROPIC,
        image_interpolation: str = 'linear',
        label_interpolation: str = 'nearest',
        pre_affine_name: str | None = None,
        scalars_only: bool = False,
        antialias: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )
        self.label_interpolation = self.parse_interpolation(
            label_interpolation,
        )
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.antialias = antialias
        self.args_names = [
            'target',
            'image_interpolation',
            'label_interpolation',
            'pre_affine_name',
            'scalars_only',
            'antialias',
        ]

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> tuple[float, float, float]:
        result: Iterable
        if isinstance(spacing, Iterable) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(spacing) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        if not isinstance(affine_name, str):
            message = f'Affine name argument must be a string, not {type(affine_name)}'
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = f'The affine matrix shape must be (4, 4), not {matrix.shape}'
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            assert self.pre_affine_name is not None  # for mypy
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for image in self.get_images(subject):
            # If the current image is the reference, don't resample it
            if self.target is image:
                continue

            # If the target is the name of this very image in the subject,
            # don't resample it. Any other target (not a string, or not a key
            # in the subject) is handled by the resampler below
            try:
                target_image = subject[self.target]
                if target_image is image:
                    continue
            except (KeyError, TypeError, RuntimeError):
                pass

            # Choose interpolation
            if not isinstance(image, ScalarImage):
                if self.scalars_only:
                    continue
                interpolation = self.label_interpolation
            else:
                interpolation = self.image_interpolation
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.pre_affine_name in image:
                assert self.pre_affine_name is not None  # for mypy
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image.affine = matrix @ image.affine

            floating_sitk = image.as_sitk(force_3d=True)

            resampler = self._get_resampler(
                interpolator,
                floating_sitk,
                subject,
                self.target,
            )
            if self.antialias and isinstance(image, ScalarImage):
                downsampling_factor = self._get_downsampling_factor(
                    floating_sitk,
                    resampler,
                )
                sigmas = self._get_sigmas(
                    downsampling_factor,
                    floating_sitk.GetSpacing(),
                )
                floating_sitk = self._smooth(floating_sitk, sigmas)
            resampled = resampler.Execute(floating_sitk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    @staticmethod
    def _smooth(
        image: sitk.Image,
        sigmas: np.ndarray,
        epsilon: float = 1e-9,
    ) -> sitk.Image:
        """Smooth the image with a Gaussian kernel.

        Args:
            image: Image to be smoothed.
            sigmas: Standard deviations of the Gaussian kernel for each
                dimension, in mm. If a value is NaN, no smoothing is applied
                in that dimension.
            epsilon: Small value to replace NaN values in sigmas, to avoid
                division-by-zero errors.
        """

        # Replace NaNs without modifying the caller's array in place; a sigma
        # of epsilon means (almost) no smoothing in that dimension
        sigmas = np.where(np.isnan(sigmas), epsilon, sigmas)
        gaussian = sitk.SmoothingRecursiveGaussianImageFilter()
        gaussian.SetSigma(sigmas.tolist())
        smoothed = gaussian.Execute(image)
        return smoothed

    @staticmethod
    def _get_downsampling_factor(
        floating: sitk.Image,
        resampler: sitk.ResampleImageFilter,
    ) -> np.ndarray:
        """Get the downsampling factor for each dimension.

        The downsampling factor is the ratio between the output spacing and
        the input spacing. If the output spacing is smaller than or equal to
        the input spacing, the factor is set to NaN, meaning that no
        anti-aliasing smoothing is applied in that dimension.

        Args:
            floating: The input image to be resampled.
            resampler: The resampler that will be used to resample the image.
        """
        input_spacing = np.array(floating.GetSpacing())
        output_spacing = np.array(resampler.GetOutputSpacing())
        factors = output_spacing / input_spacing
        no_downsampling = factors <= 1
        factors[no_downsampling] = np.nan
        return factors

    def _get_resampler(
        self,
        interpolator: int,
        floating: sitk.Image,
        subject: Subject,
        target: TypeTarget,
    ) -> sitk.ResampleImageFilter:
        """Instantiate a SimpleITK resampler."""
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(interpolator)
        self._set_resampler_reference(
            resampler,
            target,  # type: ignore[arg-type]
            floating,
            subject,
        )
        return resampler

    def _set_resampler_reference(
        self,
        resampler: sitk.ResampleImageFilter,
        target: TypeSpacing | TypePath | Image,
        floating_sitk,
        subject,
    ):
        # Target can be:
        # 1) An instance of torchio.Image
        # 2) An instance of pathlib.Path
        # 3) A string, which could be a path or an image in subject
        # 4) A number or sequence of numbers for spacing
        # 5) A tuple of shape, affine
        # Only case 4 depends on the floating image: its output grid is
        # derived from the input geometry, while the other cases define it fully
        if isinstance(target, (str, Path, Image)):
            if isinstance(target, Image):
                # It's a TorchIO image
                image = target
            elif Path(target).is_file():
                # It's an existing file
                path = target
                image = ScalarImage(path)
            else:  # assume it's the name of an image in the subject
                try:
                    image = subject[target]
                except KeyError as error:
                    message = (
                        f'Image name "{target}" not found in subject.'
                        f' If "{target}" is a path, it does not exist or'
                        ' permission has been denied'
                    )
                    raise ValueError(message) from error
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Iterable) and len(target) == 2:
            assert not isinstance(target, str)  # for mypy
            shape, affine = target
            if not (isinstance(shape, Sized) and len(shape) == 3):
                message = (
                    'Target shape must be a sequence of three integers, but'
                    f' "{shape}" was passed'
                )
                raise RuntimeError(message)
            affine = np.asarray(affine)
            if affine.shape != (4, 4):
                message = (
                    'Target affine must have shape (4, 4) but the following'
                    f' was passed:\n{affine}'
                )
                raise RuntimeError(message)
            self._set_resampler_from_shape_affine(
                resampler,
                shape,
                affine,
            )
        elif isinstance(target, Iterable) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        resampler.SetSize(shape)

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
        floating_sitk: sitk.Image,
        spacing: TypeTripletFloat,
    ) -> sitk.Image:
        old_spacing = np.array(floating_sitk.GetSpacing(), dtype=float)
        new_spacing = np.array(spacing, dtype=float)
        old_size = np.array(floating_sitk.GetSize())
        old_last_index = old_size - 1
        old_last_index_lps = np.array(
            floating_sitk.TransformIndexToPhysicalPoint(old_last_index.tolist()),
            dtype=float,
        )
        old_origin_lps = np.array(floating_sitk.GetOrigin(), dtype=float)
        center_lps = (old_last_index_lps + old_origin_lps) / 2
        # We use floor to avoid extrapolation by keeping the extent of the
        # new image the same or smaller than the original.
        new_size = np.floor(old_size * old_spacing / new_spacing)
        # We keep singleton dimensions to avoid e.g. making 2D images 3D
        new_size[old_size == 1] = 1
        direction = np.asarray(floating_sitk.GetDirection(), dtype=float).reshape(3, 3)
        half_extent = (new_size - 1) / 2 * new_spacing
        new_origin_lps = (center_lps - direction @ half_extent).tolist()
        reference = sitk.Image(
            new_size.astype(int).tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def _get_sigmas(downsampling_factor: np.ndarray, spacing: np.ndarray) -> np.ndarray:
        """Compute the optimal standard deviations of the Gaussian kernel.

        From Cardoso et al., `Scale factor point spread function matching:
        beyond aliasing in image resampling
        <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
        MICCAI 2015.

        Args:
            downsampling_factor: Array with the downsampling factor for each
                dimension. NaN values (no downsampling) yield NaN sigmas.
            spacing: Spacing of the input image in mm, one value per
                dimension.

        Returns:
            Standard deviations of the Gaussian kernel in mm, one per
            dimension.
        """
        k = downsampling_factor
        # Equation from top of page 678 of proceedings (4/9 in the PDF)
        variance = (k**2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = np.asarray(spacing, dtype=float) * np.sqrt(variance)
        return sigma
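With antialias=True, apply_transform() chains _get_downsampling_factor(), _get_sigmas() and _smooth(). The NumPy-only sketch below reproduces that arithmetic from spacings alone, using the docstring's example of an input spacing of (0.5, 0.5, 4) mm resampled to 1 mm isotropic. The helper antialias_sigmas is hypothetical, not part of TorchIO, and it leaves the NaN-to-epsilon replacement done in _smooth() to the caller.

```python
import numpy as np


def antialias_sigmas(input_spacing, output_spacing) -> np.ndarray:
    """Hypothetical helper: Gaussian sigmas (mm) applied before downsampling.

    Mirrors the arithmetic of Resample._get_downsampling_factor() and
    Resample._get_sigmas(); dimensions that are not downsampled get NaN.
    """
    input_spacing = np.asarray(input_spacing, dtype=float)
    output_spacing = np.asarray(output_spacing, dtype=float)
    # Downsampling factor per axis: output spacing / input spacing
    k = output_spacing / input_spacing
    # Factors <= 1 mean no downsampling along that axis, hence no smoothing
    k[k <= 1] = np.nan
    # 2 * sqrt(2 * ln 2) is the ratio between a Gaussian's FWHM and its sigma
    fwhm_per_sigma = 2 * np.sqrt(2 * np.log(2))
    variance = (k**2 - 1) / fwhm_per_sigma**2  # in squared voxels
    return input_spacing * np.sqrt(variance)  # voxels to mm


sigmas = antialias_sigmas((0.5, 0.5, 4.0), (1.0, 1.0, 1.0))  # ≈ [0.368, 0.368, nan]
```

Along the first two axes k = 2, so σ = 0.5 · sqrt(3 / (8 ln 2)) ≈ 0.37 mm. The third axis is upsampled (k = 0.25), so its NaN would be replaced by epsilon in _smooth() and that axis is effectively left unsmoothed.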
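When the target is a spacing, get_reference_image() keeps the resampled image centered on the original one: it takes the physical midpoint between the first and last voxel, shrinks the grid with floor so the new field of view never exceeds the old one, and places the new origin half an extent away from that midpoint along the image axes. The sketch below repeats that geometry in plain NumPy so it can be checked by hand. The name centered_reference_geometry is hypothetical, and the sketch assumes the ITK index-to-physical mapping p = origin + D @ (spacing * index), elementwise in spacing * index, which is what TransformIndexToPhysicalPoint applies.

```python
import numpy as np


def centered_reference_geometry(origin, spacing, size, direction, new_spacing):
    """Hypothetical NumPy mirror of Resample.get_reference_image().

    Returns the size and origin of a grid with new_spacing that shares the
    physical center and axis directions of the input grid.
    """
    origin = np.asarray(origin, dtype=float)
    spacing = np.asarray(spacing, dtype=float)
    size = np.asarray(size)
    direction = np.asarray(direction, dtype=float).reshape(3, 3)
    new_spacing = np.asarray(new_spacing, dtype=float)
    # Physical position of the last voxel: origin + D @ (spacing * index)
    last_point = origin + direction @ (spacing * (size - 1))
    center = (origin + last_point) / 2
    # floor keeps the new field of view no larger than the old one
    new_size = np.floor(size * spacing / new_spacing)
    new_size[size == 1] = 1  # keep singleton axes (e.g. 2D images stay 2D)
    half_extent = (new_size - 1) / 2 * new_spacing
    new_origin = center - direction @ half_extent
    return new_size.astype(int), new_origin


# A 10 x 10 x 1 grid with 1 mm voxels, resampled to 2 mm
new_size, new_origin = centered_reference_geometry(
    origin=(0, 0, 0),
    spacing=(1, 1, 1),
    size=(10, 10, 1),
    direction=np.eye(3),
    new_spacing=(2, 2, 2),
)
```

Here new_size is (5, 5, 1) and new_origin is (0.5, 0.5, 0): along the first two axes the 2 mm grid spans 0.5 to 8.5 mm, centered on the original midpoint at 4.5 mm, and the singleton third axis is kept.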