Passed
Push — master ( 0bf8ef...e85db2 )
by Fernando
01:12
created

Resample._parse_spacing()   A

Complexity

Conditions 5

Size

Total Lines 16
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 16
rs 9.2833
c 0
b 0
f 0
cc 5
nop 1
1
from pathlib import Path
2
from numbers import Number
3
from typing import Union, Tuple, Optional, Sequence
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.io import sitk_to_nib, get_sitk_metadata_from_ras_affine
10
from ....data.subject import Subject
11
from ....typing import TypeTripletFloat, TypePath
12
from ....data.image import Image, ScalarImage
13
from ... import SpatialTransform
14
15
16
# Output spacing in mm: one isotropic value, or one value per spatial axis.
TypeSpacing = Union[float, Tuple[(float,) * 3]]
17
18
19
class Resample(SpatialTransform):
    """Change voxel spacing by resampling.

    Args:
        target: Argument to define the output space. Can be one of:

            - Output spacing :math:`(s_w, s_h, s_d)`, in mm. If only one value
              :math:`s` is specified, then :math:`s_w = s_h = s_d = s`.

            - Path to an image that will be used as reference.

            - Instance of :class:`~torchio.Image`.

            - Name of an image key in the subject.

            - Tuple ``(spatial_shape, affine)`` defining the output space.

        image_interpolation: See :ref:`Interpolation`.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample(1)                     # resample all images to 1mm iso
        >>> transform = tio.Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Downsampled')
        subject.plot()

    """  # noqa: E501
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path, Image, None] = 1,
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            scalars_only: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.args_names = (
            'target',
            'image_interpolation',
            'pre_affine_name',
            'scalars_only',
        )

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Convert a spacing specification into a triplet.

        Args:
            spacing: A single number (isotropic spacing) or a sequence of
                three numbers, in mm.

        Returns:
            Tuple with one spacing value per spatial axis.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a sequence of
                three elements, or if any value is not strictly positive.
        """
        # Strings are sequences too, but a 3-character string is not a spacing
        is_sequence = (
            isinstance(spacing, Sequence)
            and not isinstance(spacing, str)
        )
        if is_sequence and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        """Validate the matrix stored under ``affine_name`` in ``image``.

        Nothing is checked if ``affine_name`` is not a key of ``image``.

        Raises:
            TypeError: If ``affine_name`` is not a string, or if the stored
                matrix is neither a NumPy array nor a PyTorch tensor.
            ValueError: If the stored matrix shape is not ``(4, 4)``.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        """Ensure at least one image in ``subject`` has the key ``affine_name``.

        Raises:
            ValueError: If no image in the subject (label maps included)
                contains the key.
        """
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample the images of ``subject`` into the target space.

        Images are modified in place and the same subject is returned. The
        target image itself is not resampled if it belongs to the subject.
        Non-scalar images use nearest-neighbor interpolation, or are skipped
        if ``scalars_only`` is ``True``. If ``pre_affine_name`` is set, the
        stored matrix is left-multiplied onto the image affine first.
        """
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.pre_affine_name, subject)

        # The target may be the name of an image in the subject. The lookup
        # does not depend on the image being processed, so do it only once.
        # TypeError covers unhashable targets, e.g. (shape, ndarray) tuples
        try:
            target_image = subject[self.target]
        except (KeyError, TypeError):
            target_image = None

        for image in self.get_images(subject):
            # Do not resample the reference image if it is in the subject
            if self.target is image or target_image is image:
                continue

            # Choose interpolation
            if not isinstance(image, ScalarImage):
                if self.scalars_only:
                    continue
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.pre_affine_name in image:
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image.affine = matrix @ image.affine

            floating_sitk = image.as_sitk(force_3d=True)

            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(interpolator)
            self._set_resampler_reference(
                resampler,
                self.target,
                floating_sitk,
                subject,
            )
            resampled = resampler.Execute(floating_sitk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    def _set_resampler_reference(
            self,
            resampler: sitk.ResampleImageFilter,
            target: Union[TypeSpacing, TypePath, Image],
            floating_sitk,
            subject,
            ):
        """Configure the output grid of ``resampler`` according to ``target``.

        Args:
            resampler: SimpleITK filter, configured in place.
            target: Specification of the output space. It can be:

                1) An instance of :class:`~torchio.Image`
                2) An instance of :class:`pathlib.Path`
                3) A string, which can be a path or the name of an image in
                   ``subject``
                4) A number or a sequence of three numbers (output spacing)
                5) A tuple ``(spatial_shape, affine)``

                Case 4 is the only one in which the output grid is derived
                from ``floating_sitk``; in the others the target defines it.
            floating_sitk: Image that will be resampled (used in case 4).
            subject: Subject in which image names (case 3) are looked up.

        Raises:
            ValueError: If a string target is neither an existing file nor an
                image name in ``subject``, or if a spacing is invalid.
            RuntimeError: If the shape or affine of case 5 is invalid, or if
                ``target`` cannot be interpreted.
        """
        if isinstance(target, (str, Path, Image)):
            if isinstance(target, Image):
                # It's a TorchIO image
                image = target
            elif Path(target).is_file():
                # It's an existing file
                image = ScalarImage(target)
            else:  # assume it's the name of an image in the subject
                try:
                    image = subject[target]
                except KeyError as error:
                    message = (
                        f'Image name "{target}" not found in subject.'
                        f' If "{target}" is a path, it does not exist or'
                        ' permission has been denied'
                    )
                    raise ValueError(message) from error
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Sequence) and len(target) == 2:
            shape, affine = target
            if not (isinstance(shape, Sequence) and len(shape) == 3):
                message = (
                    'Target shape must be a sequence of three integers, but'
                    f' "{shape}" was passed'
                )
                raise RuntimeError(message)
            # Accept any array-like affine (e.g. nested lists or tensors)
            affine = np.asarray(affine)
            if affine.shape != (4, 4):
                message = (
                    'Target affine must have shape (4, 4) but the following'
                    f' was passed:\n{affine}'
                )
                raise RuntimeError(message)
            self._set_resampler_from_shape_affine(
                resampler,
                shape,
                affine,
            )
        elif isinstance(target, Sequence) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        """Set the output grid of ``resampler`` from a shape and an affine.

        Args:
            resampler: SimpleITK filter, configured in place.
            shape: Spatial shape of the output image (three integers).
            affine: 4x4 matrix converted to SimpleITK origin, spacing and
                direction by ``get_sitk_metadata_from_ras_affine``
                (presumably in RAS convention, judging by its name).
        """
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        # Pass built-in ints even if the shape holds e.g. NumPy integers
        resampler.SetSize([int(size) for size in shape])

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        """Use a grid with spacing ``target`` that covers ``floating_sitk``.

        Raises:
            ValueError: If ``target`` is not a valid spacing.
        """
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
            floating_sitk: sitk.Image,
            spacing: TypeTripletFloat,
            ) -> sitk.Image:
        """Create an empty image defining a grid with the given spacing.

        The grid has the same direction, pixel type and number of components
        as ``floating_sitk`` and covers (rounding the size up) the same
        physical extent. Singleton dimensions are preserved.

        Args:
            floating_sitk: Image whose extent the new grid must cover.
            spacing: Output spacing along each axis.

        Returns:
            A SimpleITK image to be used as resampling reference.
        """
        old_spacing = np.array(floating_sitk.GetSpacing())
        new_spacing = np.array(spacing)
        old_size = np.array(floating_sitk.GetSize())
        new_size = old_size * old_spacing / new_spacing
        # Use a wide integer type: uint16 would silently wrap around for
        # sizes larger than 65535 voxels along an axis
        new_size = np.ceil(new_size).astype(np.int64)
        new_size[old_size == 1] = 1  # keep singleton dimensions
        # Continuous index (in the original grid) of the first new voxel
        # center, chosen so that the outer corner of the first new voxel
        # coincides with the outer corner of the first original voxel
        new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
        new_origin_lps = floating_sitk.TransformContinuousIndexToPhysicalPoint(
            new_origin_index)
        reference = sitk.Image(
            new_size.tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015

        Computes ``spacing * sqrt((k ** 2 - 1) / (8 * ln(2)))``, where ``k``
        is ``downsampling_factor``. Presumably ``k >= 1`` is expected;
        smaller values give a negative variance.

        Args:
            downsampling_factor: Ratio ``k`` between output and input spacing.
            spacing: Input spacing (number or array).

        Returns:
            Standard deviation of the Gaussian kernel, in the units of
            ``spacing``.
        """
        k = downsampling_factor
        variance = (k ** 2 - 1 ** 2) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
306