Passed
Pull Request — master (#640)
by Fernando
01:15
created

Resample.get_reference_image()   A

Complexity

Conditions 1

Size

Total Lines 23
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 21
nop 2
dl 0
loc 23
rs 9.376
c 0
b 0
f 0
1
from pathlib import Path
2
from numbers import Number
3
from typing import Union, Tuple, Optional, Sequence
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.io import sitk_to_nib, get_sitk_metadata_from_ras_affine
10
from ....data.subject import Subject
11
from ....typing import TypeTripletFloat, TypePath
12
from ....data.image import Image, ScalarImage
13
from ... import SpatialTransform
14
15
16
# Voxel spacing: a single isotropic value or one value per spatial axis.
TypeSpacing = Union[Tuple[float, float, float], float]
17
18
19
class Resample(SpatialTransform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_h, s_w, s_d)`. If only one value
            :math:`n` is specified, then :math:`s_h = s_w = s_d = n`.
            If a string or :class:`~pathlib.Path` is given,
            all images will be resampled using the image
            with that name as reference or found at the path.
            An instance of :class:`~torchio.Image` can also be passed.
            A tuple ``(shape, affine)``, where ``shape`` is a sequence of
            three integers and ``affine`` is a matrix of shape ``(4, 4)``,
            defines the output grid explicitly.
        image_interpolation: See :ref:`Interpolation`.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample(1)                     # resample all images to 1mm iso
        >>> transform = tio.Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Downsampled')
        subject.plot()

    """  # noqa: E501
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path, Image, None] = 1,
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            scalars_only: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.args_names = (
            'target',
            'image_interpolation',
            'pre_affine_name',
            'scalars_only',
        )

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Normalize a spacing specification to a 3-element value.

        Args:
            spacing: A single number (isotropic spacing) or a sequence of
                three numbers.

        Returns:
            The sequence itself if three values were given, otherwise the
            single number repeated three times.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a sequence of
                length 3, or if any value is not strictly positive.
        """
        if isinstance(spacing, Sequence) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(spacing) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        """Validate the affine matrix stored under ``affine_name`` in an image.

        Nothing is checked for the matrix if the key is absent.

        Raises:
            TypeError: If ``affine_name`` is not a string, or the stored
                matrix is neither a NumPy array nor a PyTorch tensor.
            ValueError: If the stored matrix shape is not ``(4, 4)``.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        """Ensure at least one image of the subject stores ``affine_name``.

        Raises:
            ValueError: If no image (intensity or otherwise) has that key.
        """
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample the images of ``subject`` onto the target grid in place.

        Scalar images use ``self.image_interpolation``; other images use
        nearest-neighbor interpolation, or are skipped if
        ``self.scalars_only`` is set. The image whose name equals
        ``self.target`` (if any) is left untouched.

        Returns:
            The same ``subject`` instance, with image data and affines
            replaced by the resampled versions.
        """
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            # Fail early if the requested matrix is in none of the images
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for name, image in self.get_images_dict(subject).items():
            # Do not resample the reference image if there is one
            if name == self.target:
                continue

            # Choose interpolation
            if isinstance(image, ScalarImage):
                interpolation = self.image_interpolation
            elif self.scalars_only:
                continue
            else:
                interpolation = 'nearest'
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.pre_affine_name in image:
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    # detach/cpu so conversion also works for tensors that
                    # require grad or are stored on a GPU
                    matrix = matrix.detach().cpu().numpy()
                image.affine = matrix @ image.affine

            floating_sitk = image.as_sitk(force_3d=True)

            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(interpolator)
            self._set_resampler_reference(
                resampler,
                self.target,
                floating_sitk,
                subject,
            )
            resampled = resampler.Execute(floating_sitk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    def _set_resampler_reference(
            self,
            resampler: sitk.ResampleImageFilter,
            target: Union[TypeSpacing, TypePath, Image],
            floating_sitk,
            subject,
            ):
        """Configure the output grid of ``resampler`` from ``target``.

        Target can be:
            1) An instance of torchio.Image
            2) A string or pathlib.Path pointing to an existing file
            3) A string naming an image in the subject
            4) A number or a sequence of three numbers (output spacing)
            5) A sequence ``(shape, affine)`` defining the output grid

        Raises:
            ValueError: If a string/path target is neither an existing file
                nor an image name in ``subject``.
            RuntimeError: If the target (or its shape/affine) is malformed.
        """
        if isinstance(target, Image):
            # Checked before str/Path: Path() cannot be built from an Image
            image = target
        elif isinstance(target, (str, Path)):
            if Path(target).is_file():
                image = ScalarImage(target)
            else:  # assume it's an image in the subject
                try:
                    image = subject[target]
                except KeyError as error:
                    message = (
                        f'Image name "{target}" not found in subject.'
                        f' If "{target}" is a path, it does not exist or'
                        ' permission has been denied'
                    )
                    raise ValueError(message) from error
        else:
            image = None

        if image is not None:
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Sequence) and len(target) == 2:
            shape, affine = target
            if not (isinstance(shape, Sequence) and len(shape) == 3):
                message = (
                    'Target shape must be a sequence of three integers, but'
                    f' "{shape}" was passed'
                )
                raise RuntimeError(message)
            # Accept any array-like (list of lists, tensor, array)
            affine = np.asarray(affine)
            if affine.shape != (4, 4):
                message = (
                    'Target affine must have shape (4, 4) but the following'
                    f' was passed:\n{affine}'
                )
                raise RuntimeError(message)
            self._set_resampler_from_shape_affine(
                resampler,
                shape,
                affine,
            )
        elif isinstance(target, Sequence) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        """Set output size, origin, spacing and direction from shape/affine.

        ``affine`` is converted to SimpleITK metadata by
        ``get_sitk_metadata_from_ras_affine``.
        """
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        # Plain Python ints so NumPy integer types are also accepted
        resampler.SetSize([int(n) for n in shape])

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        """Use a grid covering ``floating_sitk`` with spacing ``target``."""
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
            floating_sitk: sitk.Image,
            spacing: TypeTripletFloat,
            ) -> sitk.Image:
        """Build an empty image covering ``floating_sitk`` with new spacing.

        The new size is the old physical extent divided by the new spacing,
        rounded up; singleton dimensions stay singleton. The origin is moved
        by half a voxel ratio so the new grid is aligned with the old field
        of view. Direction, pixel type and number of components are copied
        from ``floating_sitk``.

        Args:
            floating_sitk: Image to be resampled.
            spacing: Three output spacing values.

        Returns:
            A SimpleITK image usable as a resampling reference.
        """
        old_spacing = np.array(floating_sitk.GetSpacing(), dtype=float)
        new_spacing = np.array(spacing, dtype=float)
        old_size = np.array(floating_sitk.GetSize())
        new_size = old_size * old_spacing / new_spacing
        # int64, not uint16: uint16 silently wraps around for sizes > 65535
        new_size = np.ceil(new_size).astype(np.int64)
        new_size[old_size == 1] = 1  # keep singleton dimensions
        new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
        new_origin_lps = floating_sitk.TransformContinuousIndexToPhysicalPoint(
            new_origin_index.tolist())
        reference = sitk.Image(
            new_size.tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015

        Args:
            downsampling_factor: Ratio ``k`` between new and old spacing.
            spacing: Spacing the sigma is scaled by; the result has the
                same units.

        Returns:
            ``spacing * sqrt((k**2 - 1) / (2 * sqrt(2 * ln 2))**2)``.
        """
        k = downsampling_factor
        # (2 * sqrt(2 ln 2)) is the FWHM-to-sigma conversion factor
        variance = (k ** 2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
293