Passed
Pull Request — master (#640)
by Fernando
01:22
created

torchio.transforms.preprocessing.spatial.resample   A

Complexity

Total Complexity 40

Size/Duplication

Total Lines 295
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 40
eloc 184
dl 0
loc 295
rs 9.2
c 0
b 0
f 0

10 Methods

Rating   Name   Duplication   Size   Complexity  
C Resample.apply_transform() 0 43 9
A Resample._parse_spacing() 0 16 5
A Resample.get_reference_image() 0 23 1
A Resample.get_sigma() 0 11 1
A Resample.__init__() 0 21 1
A Resample.check_affine_key_presence() 0 10 3
D Resample._set_resampler_reference() 0 63 13
A Resample._set_resampler_from_spacing() 0 7 1
A Resample.check_affine() 0 22 5
A Resample._set_resampler_from_shape_affine() 0 6 1

How to fix   Complexity   

Complexity

Complex classes like `Resample` (in the module `torchio.transforms.preprocessing.spatial.resample`) often do many different things. To break such a class down, identify a cohesive component within it. A common approach is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from pathlib import Path
2
from numbers import Number
3
from typing import Union, Tuple, Optional, Sequence
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.io import sitk_to_nib, get_sitk_metadata_from_ras_affine
10
from ....data.subject import Subject
11
from ....typing import TypeTripletFloat, TypePath
12
from ....data.image import Image, ScalarImage
13
from ... import SpatialTransform
14
15
16
# Voxel spacing accepted by Resample: either one value used for all three
# axes, or one value per axis.
TypeSpacing = Union[float, Tuple[float, float, float]]
17
18
19
class Resample(SpatialTransform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_h, s_w, s_d)`. If only one value
            :math:`n` is specified, then :math:`s_h = s_w = s_d = n`.
            If a string or :class:`~pathlib.Path` is given,
            all images will be resampled using the image
            with that name as reference or found at the path.
            An instance of :class:`~torchio.Image` can also be passed.
            A tuple ``(shape, affine)``, where ``shape`` is a sequence of
            three integers and ``affine`` is a ``(4, 4)`` matrix, can also
            be passed to define the output grid directly.
        image_interpolation: See :ref:`Interpolation`. Images that are not
            instances of :class:`~torchio.ScalarImage` (e.g. label maps)
            are always resampled using nearest neighbor interpolation.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample(1)                     # resample all images to 1mm iso
        >>> transform = tio.Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Downsampled')
        subject.plot()

    """  # noqa: E501
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path, Image, None] = 1,
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            scalars_only: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        # Not read anywhere in this class; kept as public attributes
        self.target_shape = None
        self.target_affine = None
        # Names of the constructor arguments; presumably consumed by the base
        # Transform class (e.g. to reproduce the transform) — verify there
        self.args_names = (
            'target',
            'image_interpolation',
            'pre_affine_name',
            'scalars_only',
        )

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Return ``spacing`` as a triplet after validating it.

        Args:
            spacing: A single number, used for the three axes, or a
                sequence of three numbers.

        Returns:
            A tuple with one spacing value per axis.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a sequence of
                length 3, or if any of its values is not strictly positive.
        """
        if isinstance(spacing, Sequence) and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        """Validate the affine matrix stored in ``image`` under ``affine_name``.

        Nothing is checked beyond the key type if the key is not in ``image``.

        Raises:
            TypeError: If ``affine_name`` is not a string, or if the stored
                matrix is neither a NumPy array nor a PyTorch tensor.
            ValueError: If the stored matrix does not have shape ``(4, 4)``.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        """Raise if no image in ``subject`` contains the key ``affine_name``.

        Raises:
            ValueError: If none of the subject images (intensity and
                non-intensity) has ``affine_name`` among its keys.
        """
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample the images of ``subject`` in place and return it.

        The image whose name equals ``self.target`` (if any) is not
        resampled. Non-scalar images are skipped when ``scalars_only`` is
        set and are otherwise resampled with nearest neighbor interpolation.
        If ``pre_affine_name`` is set and found in an image, that matrix is
        left-multiplied onto the image affine before resampling.
        """
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for name, image in self.get_images_dict(subject).items():
            # Do not resample the reference image if there is one
            if name == self.target:
                continue

            # Choose interpolation
            if not isinstance(image, ScalarImage):
                if self.scalars_only:
                    continue
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.pre_affine_name in image:
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image.affine = matrix @ image.affine

            floating_sitk = image.as_sitk(force_3d=True)

            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(interpolator)
            self._set_resampler_reference(
                resampler,
                self.target,
                floating_sitk,
                subject,
            )
            resampled = resampler.Execute(floating_sitk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    def _set_resampler_reference(
            self,
            resampler: sitk.ResampleImageFilter,
            target: Union[TypeSpacing, TypePath, Image],
            floating_sitk,
            subject,
            ):
        """Configure the output grid of ``resampler`` from ``target``.

        ``target`` can be:

        1. An instance of :class:`~torchio.Image`, used as reference.
        2. A :class:`~pathlib.Path` or string pointing to an existing file,
           read as a :class:`~torchio.ScalarImage` and used as reference.
        3. Any other string, interpreted as an image name in ``subject``.
        4. A number or a sequence of three numbers, interpreted as the
           output spacing; the output grid is derived from ``floating_sitk``.
        5. A sequence ``(shape, affine)`` describing the output grid.

        Raises:
            ValueError: If a string target is neither an existing file nor
                the name of an image in ``subject``.
            RuntimeError: If ``target`` cannot be interpreted, or if a
                ``(shape, affine)`` target is malformed.
        """
        if isinstance(target, (str, Path, Image)):
            image = self._get_image_from_target(target, subject)
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Sequence) and len(target) == 2:
            shape, affine = self._parse_shape_affine(target)
            self._set_resampler_from_shape_affine(resampler, shape, affine)
        elif isinstance(target, Sequence) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    @staticmethod
    def _get_image_from_target(
            target: Union[TypePath, Image],
            subject: Subject,
            ) -> Image:
        """Return the reference image described by an image, path or name.

        Raises:
            ValueError: If ``target`` is neither an existing file nor the
                name of an image in ``subject``.
        """
        # The Image check must come first: Path(image) raises a TypeError
        if isinstance(target, Image):
            return target
        if Path(target).is_file():
            return ScalarImage(target)
        # Otherwise, assume it's the name of an image in the subject
        try:
            return subject[target]
        except KeyError as error:
            message = (
                f'Image name "{target}" not found in subject.'
                f' If "{target}" is a path, it does not exist or'
                ' permission has been denied'
            )
            raise ValueError(message) from error

    @staticmethod
    def _parse_shape_affine(target: Sequence):
        """Validate a ``(shape, affine)`` target and return its two parts.

        Affines that are neither NumPy arrays nor PyTorch tensors (e.g.
        nested lists) are converted to NumPy arrays.

        Raises:
            RuntimeError: If ``shape`` is not a sequence of length 3 or if
                ``affine`` does not have shape ``(4, 4)``.
        """
        shape, affine = target
        if not (isinstance(shape, Sequence) and len(shape) == 3):
            message = (
                'Target shape must be a sequence of three integers, but'
                f' "{shape}" was passed'
            )
            raise RuntimeError(message)
        if not isinstance(affine, (np.ndarray, torch.Tensor)):
            affine = np.asarray(affine, dtype=float)
        if tuple(affine.shape) != (4, 4):
            message = (
                'Target affine must have shape (4, 4) but the following'
                f' was passed:\n{affine}'
            )
            raise RuntimeError(message)
        return shape, affine

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        """Set the output grid of ``resampler`` from a shape and an affine.

        The affine is converted to SimpleITK origin, spacing and direction by
        ``get_sitk_metadata_from_ras_affine`` (whose name indicates it
        expects a RAS affine).
        """
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        resampler.SetSize(shape)

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        """Set the output grid of ``resampler`` from a target spacing.

        The grid is derived from ``floating_sitk`` by
        :meth:`get_reference_image`.
        """
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
            floating_sitk: sitk.Image,
            spacing: TypeTripletFloat,
            ) -> sitk.Image:
        """Create an empty image covering ``floating_sitk`` with new spacing.

        The new size is ``ceil(old_size * old_spacing / new_spacing)`` per
        axis, except that singleton axes keep size 1. The origin is placed so
        that the lower edge of the first new voxel coincides with the lower
        edge of the first voxel of ``floating_sitk``. Direction, pixel type
        and number of components are copied from ``floating_sitk``.
        """
        old_spacing = np.array(floating_sitk.GetSpacing())
        new_spacing = np.array(spacing)
        old_size = np.array(floating_sitk.GetSize())
        new_size = old_size * old_spacing / new_spacing
        # A 64-bit integer is used because uint16 silently wraps around for
        # output sizes larger than 65535 voxels (e.g. strong upsampling)
        new_size = np.ceil(new_size).astype(np.int64)
        new_size[old_size == 1] = 1  # keep singleton dimensions
        # In continuous index units of the floating image, old voxel 0 spans
        # [-0.5, 0.5]; centering new voxel 0 here aligns both lower edges
        new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
        new_origin_lps = floating_sitk.TransformContinuousIndexToPhysicalPoint(
            new_origin_index)
        reference = sitk.Image(
            new_size.tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015

        Args:
            downsampling_factor: Ratio :math:`k` between the new and the old
                spacing.
            spacing: Spacing used to scale the standard deviation.

        Returns:
            ``spacing * sqrt((k ** 2 - 1) / (2 * sqrt(2 * ln(2))) ** 2)``.
        """
        k = downsampling_factor
        # (2 * sqrt(2 * ln 2)) is the ratio between a Gaussian FWHM and sigma
        variance = (k ** 2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
295