Passed
Pull Request — main (#1314), by Fernando, created 01:27

torchio.transforms.preprocessing.spatial.resample (Rating: B)

Complexity

Total Complexity 47

Size/Duplication

Total Lines 425
Duplicated Lines 0%

Importance

Changes 0
Metric Value
wmc 47
eloc 249
dl 0
loc 425
rs 8.64
c 0
b 0
f 0

13 Methods

Rating   Name   Duplication   Size   Complexity  
A Resample._get_downsampling_factor() 0 22 1
D Resample.apply_transform() 0 62 13
A Resample._smooth() 0 22 1
A Resample._parse_spacing() 0 17 5
A Resample.__init__() 0 28 1
A Resample.check_affine_key_presence() 0 10 3
A Resample._get_resampler() 0 17 1
D Resample._set_resampler_reference() 0 63 13
A Resample._set_resampler_from_spacing() 0 7 1
A Resample.check_affine() 0 16 5
A Resample._set_resampler_from_shape_affine() 0 6 1
A Resample.get_reference_image() 0 32 1
A Resample._get_sigmas() 0 19 1

How to fix: Complexity

Complex classes like torchio.transforms.preprocessing.spatial.resample often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields and methods belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
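For this module, one plausible seam (a sketch only; ReferenceSpace and from_target are hypothetical names, not part of TorchIO's API) is the target-resolution logic shared by _set_resampler_reference(), _set_resampler_from_spacing(), and _set_resampler_from_shape_affine(), which accounts for much of the class's complexity:

# Sketch only: the "output space" definition becomes its own class, so the
# transform itself only orchestrates SimpleITK calls.
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from numbers import Number
from pathlib import Path


@dataclass
class ReferenceSpace:
    """Everything needed to define the resampling output space."""

    spacing: tuple[float, float, float] | None = None
    shape: tuple[int, int, int] | None = None
    affine: Sequence[Sequence[float]] | None = None  # 4x4 matrix

    @classmethod
    def from_target(cls, target) -> ReferenceSpace:
        # The branching currently spread across _set_resampler_reference()
        # would move here.
        if isinstance(target, Number):
            return cls(spacing=(float(target),) * 3)
        if isinstance(target, (str, Path)):
            # In the real class, this would load the image header (or look
            # the key up in the subject) and keep its shape and affine.
            raise NotImplementedError
        if isinstance(target, Sequence) and len(target) == 2:
            shape, affine = target
            return cls(shape=tuple(shape), affine=affine)
        if isinstance(target, Sequence) and len(target) == 3:
            return cls(spacing=tuple(float(s) for s in target))
        raise ValueError(f'Target not understood: "{target}"')

With a collaborator like this, Resample._set_resampler_reference() would only translate a ReferenceSpace into calls on the sitk.ResampleImageFilter, which is where most of its current cyclomatic complexity (13) comes from.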

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Sized
from numbers import Number
from pathlib import Path
from typing import Union

import numpy as np
import SimpleITK as sitk
import torch

from ....data.image import Image
from ....data.image import ScalarImage
from ....data.io import get_sitk_metadata_from_ras_affine
from ....data.io import sitk_to_nib
from ....data.subject import Subject
from ....types import TypePath
from ....types import TypeTripletFloat
from ...spatial_transform import SpatialTransform

TypeSpacing = Union[float, tuple[float, float, float]]
TypeTarget = Union[TypeSpacing, str, Path, Image, None]
ONE_MILLIMITER_ISOTROPIC = 1


class Resample(SpatialTransform):
    """Resample image to a different physical space.

    This is a powerful transform that can be used to change the image shape
    or spatial metadata, or to apply a spatial transformation.

    Args:
        target: Argument to define the output space. Can be one of:

            - Output spacing :math:`(s_w, s_h, s_d)`, in mm. If only one value
              :math:`s` is specified, then :math:`s_w = s_h = s_d = s`.

            - Path to an image that will be used as reference.

            - Instance of :class:`~torchio.Image`.

            - Name of an image key in the subject.

            - Tuple ``(spatial_shape, affine)`` defining the output space.

        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        image_interpolation: See :ref:`Interpolation`.
        label_interpolation: See :ref:`Interpolation`.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        antialias: If ``True``, apply a Gaussian smoothing before
            downsampling, along any dimension that will be downsampled.
            This is useful to avoid aliasing artifacts when downsampling
            images. The standard deviation of the Gaussian kernel
            is computed according to the method described in Cardoso et al.,
            `Scale factor point spread function matching: beyond aliasing in
            image resampling
            <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
            MICCAI 2015.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample(1)                     # resample all images to 1mm iso
        >>> transform = tio.Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Downsampled')
        subject.plot()
    """

    def __init__(
        self,
        target: TypeTarget = ONE_MILLIMITER_ISOTROPIC,
        image_interpolation: str = 'linear',
        label_interpolation: str = 'nearest',
        pre_affine_name: str | None = None,
        scalars_only: bool = False,
        antialias: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )
        self.label_interpolation = self.parse_interpolation(
            label_interpolation,
        )
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.antialias = antialias
        self.args_names = [
            'target',
            'image_interpolation',
            'label_interpolation',
            'pre_affine_name',
            'scalars_only',
            'antialias',
        ]

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> tuple[float, float, float]:
        result: Iterable
        if isinstance(spacing, Iterable) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(spacing) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        if not isinstance(affine_name, str):
            message = f'Affine name argument must be a string, not {type(affine_name)}'
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = f'The affine matrix shape must be (4, 4), not {matrix.shape}'
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            assert self.pre_affine_name is not None  # for mypy
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for image in self.get_images(subject):
            # If the current image is the reference, don't resample it
            if self.target is image:
                continue

            # If the target is not a string, or is not an image in the subject,
            # do nothing
            try:
                target_image = subject[self.target]
                if target_image is image:
                    continue
            except (KeyError, TypeError, RuntimeError):
                pass

            # Choose interpolation
            if not isinstance(image, ScalarImage):
                if self.scalars_only:
                    continue
                interpolation = self.label_interpolation
            else:
                interpolation = self.image_interpolation
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.pre_affine_name in image:
                assert self.pre_affine_name is not None  # for mypy
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image.affine = matrix @ image.affine

            floating_sitk = image.as_sitk(force_3d=True)

            resampler = self._get_resampler(
                interpolator,
                floating_sitk,
                subject,
                self.target,
            )
            if self.antialias and isinstance(image, ScalarImage):
                downsampling_factor = self._get_downsampling_factor(
                    floating_sitk,
                    resampler,
                )
                sigmas = self._get_sigmas(
                    downsampling_factor,
                    floating_sitk.GetSpacing(),
                )
                floating_sitk = self._smooth(floating_sitk, sigmas)
            resampled = resampler.Execute(floating_sitk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    @staticmethod
    def _smooth(
        image: sitk.Image,
        sigmas: np.ndarray,
        epsilon: float = 1e-9,
    ) -> sitk.Image:
        """Smooth the image with a Gaussian kernel.

        Args:
            image: Image to be smoothed.
            sigmas: Standard deviations of the Gaussian kernel for each
                dimension. If a value is NaN, no smoothing is applied in that
                dimension.
            epsilon: Small value to replace NaN values in sigmas, to avoid
                division-by-zero errors.
        """

        sigmas[np.isnan(sigmas)] = epsilon  # no smoothing in that dimension
        gaussian = sitk.SmoothingRecursiveGaussianImageFilter()
        gaussian.SetSigma(sigmas.tolist())
        smoothed = gaussian.Execute(image)
        return smoothed

    @staticmethod
    def _get_downsampling_factor(
        floating: sitk.Image,
        resampler: sitk.ResampleImageFilter,
    ) -> np.ndarray:
        """Get the downsampling factor for each dimension.

        The downsampling factor is the ratio between the output spacing and
        the input spacing. If the output spacing is smaller than the input
        spacing, the factor is set to NaN, meaning downsampling is not applied
        in that dimension.

        Args:
            floating: The input image to be resampled.
            resampler: The resampler that will be used to resample the image.
        """
        input_spacing = np.array(floating.GetSpacing())
        output_spacing = np.array(resampler.GetOutputSpacing())
        factors = output_spacing / input_spacing
        no_downsampling = factors <= 1
        factors[no_downsampling] = np.nan
        return factors

    def _get_resampler(
        self,
        interpolator: int,
        floating: sitk.Image,
        subject: Subject,
        target: TypeTarget,
    ) -> sitk.ResampleImageFilter:
        """Instantiate a SimpleITK resampler."""
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(interpolator)
        self._set_resampler_reference(
            resampler,
            target,  # type: ignore[arg-type]
            floating,
            subject,
        )
        return resampler

    def _set_resampler_reference(
        self,
        resampler: sitk.ResampleImageFilter,
        target: TypeSpacing | TypePath | Image,
        floating_sitk,
        subject,
    ):
        # Target can be:
        # 1) An instance of torchio.Image
        # 2) An instance of pathlib.Path
        # 3) A string, which could be a path or an image in subject
        # 4) A number or sequence of numbers for spacing
        # 5) A tuple of shape, affine
        # The fourth case is the different one
        if isinstance(target, (str, Path, Image)):
            if isinstance(target, Image):
                # It's a TorchIO image
                image = target
            elif Path(target).is_file():
                # It's an existing file
                path = target
                image = ScalarImage(path)
            else:  # assume it's the name of an image in the subject
                try:
                    image = subject[target]
                except KeyError as error:
                    message = (
                        f'Image name "{target}" not found in subject.'
                        f' If "{target}" is a path, it does not exist or'
                        ' permission has been denied'
                    )
                    raise ValueError(message) from error
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Iterable) and len(target) == 2:
            assert not isinstance(target, str)  # for mypy
            shape, affine = target
            if not (isinstance(shape, Sized) and len(shape) == 3):
                message = (
                    'Target shape must be a sequence of three integers, but'
                    f' "{shape}" was passed'
                )
                raise RuntimeError(message)
            if not affine.shape == (4, 4):
                message = (
                    'Target affine must have shape (4, 4) but the following'
                    f' was passed:\n{affine}'
                )
                raise RuntimeError(message)
            self._set_resampler_from_shape_affine(
                resampler,
                shape,
                affine,
            )
        elif isinstance(target, Iterable) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        resampler.SetSize(shape)

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
        floating_sitk: sitk.Image,
        spacing: TypeTripletFloat,
    ) -> sitk.Image:
        old_spacing = np.array(floating_sitk.GetSpacing(), dtype=float)
        new_spacing = np.array(spacing, dtype=float)
        old_size = np.array(floating_sitk.GetSize())
        old_last_index = old_size - 1
        old_last_index_lps = np.array(
            floating_sitk.TransformIndexToPhysicalPoint(old_last_index.tolist()),
            dtype=float,
        )
        old_origin_lps = np.array(floating_sitk.GetOrigin(), dtype=float)
        center_lps = (old_last_index_lps + old_origin_lps) / 2
        # We use floor to avoid extrapolation by keeping the extent of the
        # new image the same or smaller than the original.
        new_size = np.floor(old_size * old_spacing / new_spacing)
        # We keep singleton dimensions to avoid e.g. making 2D images 3D
        new_size[old_size == 1] = 1
        direction = np.asarray(floating_sitk.GetDirection(), dtype=float).reshape(3, 3)
        half_extent = (new_size - 1) / 2 * new_spacing
        new_origin_lps = (center_lps - direction @ half_extent).tolist()
        reference = sitk.Image(
            new_size.astype(int).tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def _get_sigmas(downsampling_factor: np.ndarray, spacing: np.ndarray) -> np.ndarray:
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., `Scale factor point spread function matching:
        beyond aliasing in image resampling
        <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
        MICCAI 2015.

        Args:
            downsampling_factor: Array with the downsampling factor for each
                dimension.
            spacing: Array with the spacing of the input image in mm.
        """
        k = downsampling_factor
        # Equation from top of page 678 of proceedings (4/9 in the PDF)
        variance = (k**2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
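
As a quick numerical check of the antialiasing formula implemented in _get_sigmas() above (an illustration added here, not part of the analysed file; expected_sigma is a hypothetical helper), downsampling a 1 mm image to 2 mm gives k = 2 and a smoothing sigma of roughly 0.74 mm:

import numpy as np

def expected_sigma(downsampling_factor: float, spacing_mm: float) -> float:
    # Cardoso et al., MICCAI 2015: variance = (k**2 - 1) / (2 * sqrt(2 * ln 2))**2
    k = downsampling_factor
    variance = (k**2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
    return spacing_mm * np.sqrt(variance)

print(round(expected_sigma(2.0, 1.0), 3))  # 0.736 (mm), i.e. FWHM of about 1.73 mm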