Resample.apply_transform()   C
last analyzed

Complexity

Conditions 11

Size

Total Lines 49
Code Lines 38

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 38
dl 0
loc 49
rs 5.4
c 0
b 0
f 0
cc 11
nop 2

How to fix   Complexity   

Complexity

Complex methods like torchio.transforms.preprocessing.spatial.resample.Resample.apply_transform() often do a lot of different things. To break such a method down, we need to identify cohesive groups of statements within it. A common approach is to look for steps that operate on the same data, or that are already separated by explanatory comments (e.g. choosing the interpolation, applying a pre-affine matrix, running the resampler).

Once you have determined the statements that belong together, you can apply the Extract Method refactoring, moving each group into a well-named helper. For whole classes, the analogous refactorings are Extract Class and, if the component makes sense as a sub-class, Extract Subclass, which is often faster.

1
from pathlib import Path
2
from numbers import Number
3
from typing import Union, Tuple, Optional
4
from collections.abc import Iterable
5
6
import torch
7
import numpy as np
8
import SimpleITK as sitk
9
10
from ....data.io import sitk_to_nib, get_sitk_metadata_from_ras_affine
11
from ....data.subject import Subject
12
from ....typing import TypeTripletFloat, TypePath
13
from ....data.image import Image, ScalarImage
14
from ... import SpatialTransform
15
16
17
# Output spacing in mm: a single (isotropic) value or one value per axis.
TypeSpacing = Union[
    float,
    Tuple[float, float, float],
]
18
19
20
class Resample(SpatialTransform):
    """Resample image to a different physical space.

    This is a powerful transform that can be used to change the image shape
    or spatial metadata, or to apply a spatial transformation.

    Args:
        target: Argument to define the output space. Can be one of:

            - Output spacing :math:`(s_w, s_h, s_d)`, in mm. If only one value
              :math:`s` is specified, then :math:`s_w = s_h = s_d = s`.

            - Path to an image that will be used as reference.

            - Instance of :class:`~torchio.Image`.

            - Name of an image key in the subject.

            - Tuple ``(spatial_shape, affine)`` defining the output space.

        image_interpolation: See :ref:`Interpolation`.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample(1)                     # resample all images to 1mm iso
        >>> transform = tio.Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Downsampled')
        subject.plot()

    """  # noqa: E501
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path, Image, None] = 1,
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            scalars_only: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.args_names = (
            'target',
            'image_interpolation',
            'pre_affine_name',
            'scalars_only',
        )

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Normalize ``spacing`` to three strictly positive values.

        Args:
            spacing: A single number (same spacing on all axes) or a
                sequence of three numbers.

        Returns:
            The sequence itself if it has three elements, otherwise a
            3-tuple repeating the single number.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a sequence of
                length 3, or if any value is not strictly positive.
        """
        if isinstance(spacing, Iterable) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(spacing) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        """Validate the affine matrix stored in ``image`` under ``affine_name``.

        Nothing is checked beyond the name's type if ``image`` has no entry
        called ``affine_name``.

        Raises:
            TypeError: If ``affine_name`` is not a string, or if the stored
                matrix is neither a NumPy array nor a PyTorch tensor.
            ValueError: If the stored matrix shape is not ``(4, 4)``.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        """Ensure at least one image in ``subject`` stores ``affine_name``.

        Raises:
            ValueError: If no image (intensity or not) has that key.
        """
        images = subject.get_images(intensity_only=False)
        if any(affine_name in image for image in images):
            return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image of ``subject`` to the target space.

        Images are modified in place (data and affine) and the same
        ``subject`` is returned. The reference image is not resampled when
        it is the target itself, and non-scalar images are skipped when
        ``scalars_only`` is set.

        Raises:
            ValueError: If ``pre_affine_name`` is set but no image in the
                subject has an entry with that name.
        """
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for image in self.get_images(subject):
            if self._should_skip(image, subject):
                continue
            # Resolve the interpolator before modifying the image header
            interpolation = self._get_interpolation(image)
            interpolator = self.get_sitk_interpolator(interpolation)
            if use_pre_affine:
                self._apply_pre_affine(image)
            self._resample_image(image, interpolator, subject)
        return subject

    def _should_skip(self, image: Image, subject: Subject) -> bool:
        """Return ``True`` if ``image`` must not be resampled.

        Two cases are skipped. The first is the reference image itself,
        given either directly as ``target`` or as the subject key named by
        ``target``. The second, when ``scalars_only`` is set, is any image
        that is not a :class:`~torchio.ScalarImage`.
        """
        if self.target is image:
            return True
        try:
            if subject[self.target] is image:
                return True
        except (KeyError, TypeError):
            # The target is not a key of the subject, e.g. a spacing, a
            # path, or an unhashable (shape, affine) pair
            pass
        if self.scalars_only and not isinstance(image, ScalarImage):
            return True
        return False

    def _get_interpolation(self, image: Image) -> str:
        """Return the interpolation name to use for ``image``.

        Scalar images use ``image_interpolation``. Any other image (e.g. a
        label map) uses nearest neighbor.
        """
        if isinstance(image, ScalarImage):
            return self.image_interpolation
        return 'nearest'

    def _apply_pre_affine(self, image: Image) -> None:
        """Left-multiply ``image.affine`` by the matrix under ``pre_affine_name``.

        Images without an entry named ``pre_affine_name`` are left unchanged.
        """
        if self.pre_affine_name not in image:
            return
        self.check_affine(self.pre_affine_name, image)
        matrix = image[self.pre_affine_name]
        if isinstance(matrix, torch.Tensor):
            # detach() and cpu() let GPU tensors and tensors that require
            # grad be converted; a plain .numpy() raises for those
            matrix = matrix.detach().cpu().numpy()
        image.affine = matrix @ image.affine

    def _resample_image(self, image: Image, interpolator, subject: Subject):
        """Resample ``image`` in place to the target space with SimpleITK."""
        floating_sitk = image.as_sitk(force_3d=True)
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(interpolator)
        self._set_resampler_reference(
            resampler,
            self.target,
            floating_sitk,
            subject,
        )
        resampled = resampler.Execute(floating_sitk)
        array, affine = sitk_to_nib(resampled)
        image.set_data(torch.as_tensor(array))
        image.affine = affine

    def _set_resampler_reference(
            self,
            resampler: sitk.ResampleImageFilter,
            target: Union[TypeSpacing, TypePath, Image],
            floating_sitk,
            subject,
            ):
        """Configure the output grid of ``resampler`` from ``target``.

        ``target`` can be:

        1. An instance of :class:`~torchio.Image`.
        2. A ``str`` or :class:`pathlib.Path` pointing to an existing file.
        3. A ``str`` naming an image in ``subject``.
        4. A number or a sequence of three numbers (output spacing).
           These are the only cases that depend on ``floating_sitk``.
        5. A pair ``(spatial_shape, affine)``.

        Raises:
            ValueError: If a string/path target is neither an existing file
                nor an image name in ``subject``.
            RuntimeError: If the shape/affine pair is malformed or the target
                cannot be interpreted.
        """
        if isinstance(target, (str, Path, Image)):
            if isinstance(target, Image):
                image = target
            elif Path(target).is_file():
                image = ScalarImage(target)
            else:  # assume it's the name of an image in the subject
                try:
                    image = subject[target]
                except KeyError as error:
                    message = (
                        f'Image name "{target}" not found in subject.'
                        f' If "{target}" is a path, it does not exist or'
                        ' permission has been denied'
                    )
                    raise ValueError(message) from error
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Iterable) and len(target) == 2:
            shape, affine = target
            if not (isinstance(shape, Iterable) and len(shape) == 3):
                message = (
                    'Target shape must be a sequence of three integers, but'
                    f' "{shape}" was passed'
                )
                raise RuntimeError(message)
            if not affine.shape == (4, 4):
                message = (
                    'Target affine must have shape (4, 4) but the following'
                    f' was passed:\n{affine}'
                )
                raise RuntimeError(message)
            self._set_resampler_from_shape_affine(
                resampler,
                shape,
                affine,
            )
        elif isinstance(target, Iterable) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        """Set the output grid of ``resampler`` from a shape and RAS affine."""
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        resampler.SetSize(shape)

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        """Use a grid covering ``floating_sitk`` with the ``target`` spacing."""
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
            floating_sitk: sitk.Image,
            spacing: TypeTripletFloat,
            ) -> sitk.Image:
        """Build an empty image covering ``floating_sitk`` with a new spacing.

        The new size is ``ceil(old_size * old_spacing / new_spacing)`` per
        axis, except that singleton axes stay singleton. Direction, pixel
        type and number of components are copied from ``floating_sitk``.

        Args:
            floating_sitk: Image whose field of view is reproduced.
            spacing: New spacing, one value per axis.

        Returns:
            A SimpleITK image to be used as resampling reference.
        """
        old_spacing = np.array(floating_sitk.GetSpacing())
        new_spacing = np.array(spacing)
        old_size = np.array(floating_sitk.GetSize())
        new_size = old_size * old_spacing / new_spacing
        # int64 rather than uint16: uint16 silently wraps any axis larger
        # than 65535 voxels (e.g. when strongly upsampling)
        new_size = np.ceil(new_size).astype(np.int64)
        new_size[old_size == 1] = 1  # keep singleton dimensions
        # Continuous index (in input voxels) of the center of the first
        # output voxel. It places the corner of the output grid at index
        # -0.5, the corner of the input grid.
        new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
        new_origin_lps = floating_sitk.TransformContinuousIndexToPhysicalPoint(
            new_origin_index)
        reference = sitk.Image(
            new_size.tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015.

        The variance is ``(k^2 - 1) / (2 * sqrt(2 * ln 2))^2``, where ``k``
        is ``downsampling_factor``. The result is scaled by ``spacing``.

        Args:
            downsampling_factor: Downsampling factor ``k``.
            spacing: Spacing by which the unitless sigma is multiplied.

        Returns:
            ``spacing * sqrt(variance)``.
        """
        k = downsampling_factor
        variance = (k ** 2 - 1 ** 2) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
310