Passed
Pull Request — master (#272)
by Fernando
01:31
created

Resample.get_reference_image()   A

Complexity

Conditions 1

Size

Total Lines 18
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 17
nop 2
dl 0
loc 18
rs 9.55
c 0
b 0
f 0
1
from pathlib import Path
2
from numbers import Number
3
from typing import Union, Tuple, Optional, List
4
5
import torch
6
import numpy as np
7
import nibabel as nib
8
import SimpleITK as sitk
9
from nibabel.processing import resample_to_output, resample_from_to
10
11
from ....data.subject import Subject
12
from ....data.image import Image, ScalarImage
13
from ....torchio import DATA, AFFINE, TYPE, INTENSITY, TypeData, TypeTripletFloat
14
from ....utils import sitk_to_nib
15
from ... import SpatialTransform
16
from ... import Interpolation, get_sitk_interpolator
17
18
19
20
# Per-axis triplet of spacing values
_TypeSpacingTriplet = Tuple[float, float, float]
# Spacing given either as one value (same for every axis) or per axis
TypeSpacing = Union[float, _TypeSpacingTriplet]
# Result of Resample.parse_target: (reference image or its name, spacing).
# Exactly one of the two elements is not None.
TypeTarget = Tuple[
    Optional[Union[Image, str]],
    Optional[_TypeSpacingTriplet],
]
25
26
27
class Resample(SpatialTransform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_h, s_w, s_d)`. If only one value
            :math:`n` is specified, then :math:`s_h = s_w = s_d = n`.
            If a string or :py:class:`~pathlib.Path` is given,
            all images will be resampled using the image
            with that name as reference or found at the path.
        image_interpolation: String that defines the interpolation technique.
            Supported interpolation techniques for resampling
            are ``'nearest'``, ``'linear'`` and ``'bspline'``.
            Using a member of :py:class:`torchio.Interpolation` is still
            supported for backward compatibility,
            but will be removed in a future version.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: Resampling is performed with SimpleITK's
        ``ResampleImageFilter``. Images whose type is not ``INTENSITY`` are
        always resampled with nearest-neighbor interpolation. When the
        reference is given as the name of an image in the sample, that image
        is left untouched.

    Example:
        >>> import torchio
        >>> from torchio import Resample
        >>> from torchio.datasets import Colin27, FPG
        >>> transform = Resample(1)                     # resample all images to 1mm iso
        >>> transform = Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = Resample('t1')                  # resample all images to 't1' image space
        >>> colin = Colin27()  # these images are in the MNI space
        >>> fpg = FPG()  # matrices to the MNI space are included here
        >>> # Resample all images into the MNI space
        >>> transform = Resample(colin.t1.path, pre_affine_name='affine_matrix')
        >>> transformed = transform(fpg)  # images in fpg are now in MNI space
    """
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path],
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.reference_image, self.target_spacing = self.parse_target(target)
        self.interpolation = self.parse_interpolation(image_interpolation)
        self.affine_name = pre_affine_name

    def parse_target(
            self,
            target: Union[TypeSpacing, str, Path],
            ) -> TypeTarget:
        """Split ``target`` into a reference image and a target spacing.

        Returns:
            A tuple ``(reference_image, target_spacing)`` in which exactly one
            element is not ``None``:

            - if ``target`` is a string or path to an existing file,
              ``reference_image`` is a :py:class:`ScalarImage` read from it;
            - if it is a string or path that is not an existing file,
              ``reference_image`` is its string form, later looked up by name
              in the sample;
            - otherwise, ``target_spacing`` is the parsed spacing triplet.
        """
        if isinstance(target, (str, Path)):
            if Path(target).is_file():
                reference_image = ScalarImage(target)
            else:
                # Always store names as str so apply_transform can look them
                # up in the sample (a Path here would match no lookup branch)
                reference_image = str(target)
            target_spacing = None
        else:
            reference_image = None
            target_spacing = self.parse_spacing(target)
        return reference_image, target_spacing

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Return ``spacing`` as a triplet after checking it is positive.

        Args:
            spacing: A single number (same spacing for every axis) or a
                tuple or list of three numbers.

        Raises:
            ValueError: If ``spacing`` has an unsupported type or length, or
                if any of its values is not positive.
        """
        if isinstance(spacing, (tuple, list)) and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    @staticmethod
    def check_affine(affine_name: str, image_dict: dict):
        """Validate the pre-affine matrix stored under ``affine_name``.

        Nothing is checked about the matrix if ``image_dict`` has no entry
        named ``affine_name``.

        Raises:
            TypeError: If ``affine_name`` is not a string, or if the stored
                matrix is neither a NumPy array nor a PyTorch tensor.
            ValueError: If the stored matrix shape is not ``(4, 4)``.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image_dict:
            matrix = image_dict[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch tensor,'
                    f' not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, sample: Subject):
        """Ensure at least one image in ``sample`` has an ``affine_name`` key.

        Raises:
            ValueError: If no image (intensity or not) contains the key.
        """
        for image_dict in sample.get_images(intensity_only=False):
            if affine_name in image_dict:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the sample'
        )
        raise ValueError(message)

    def _get_fixed_reference_sitk(
            self,
            sample: Subject,
            ) -> Optional[sitk.Image]:
        """Return the reference grid shared by all images, or ``None``.

        ``None`` means the target is a spacing, so the output grid depends on
        each floating image and must be built per image.

        Raises:
            ValueError: If the reference is a name not found in ``sample``.
            TypeError: If the stored reference has an unsupported type.
        """
        if self.reference_image is None:
            return None
        if isinstance(self.reference_image, str):
            try:
                reference = sample[self.reference_image]
            except KeyError as error:
                message = (
                    f'Reference name "{self.reference_image}"'
                    ' not found in sample'
                )
                raise ValueError(message) from error
            return reference.as_sitk()
        if isinstance(self.reference_image, ScalarImage):
            return self.reference_image.as_sitk()
        message = (
            'Reference image must be a name or a ScalarImage,'
            f' not {type(self.reference_image)}'
        )
        raise TypeError(message)

    def apply_transform(self, sample: Subject) -> dict:
        """Resample the selected images of ``sample`` in place.

        For each image, the optional pre-affine is composed with the image
        affine, then data and affine are replaced by the resampled ones. The
        reference image, when given by name, is skipped.

        Raises:
            ValueError: If ``pre_affine_name`` was given but no image has it,
                or if the reference name is not found in ``sample``.
        """
        use_pre_affine = self.affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.affine_name, sample)

        # The grid is the same for all images unless the target is a spacing.
        # Computing it once also makes the result independent of image order.
        fixed_reference_sitk = self._get_fixed_reference_sitk(sample)
        reference_name = (
            self.reference_image
            if isinstance(self.reference_image, str)
            else None
        )

        for image_name, image in self.get_images_dict(sample).items():
            # Do not resample the reference image if there is one
            if image_name == reference_name:
                continue

            # Choose interpolation
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation
            interpolator = get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.affine_name in image:
                self.check_affine(self.affine_name, image)
                matrix = image[self.affine_name]
                if isinstance(matrix, torch.Tensor):
                    # detach/cpu so tensors with grad or on GPU also work
                    matrix = matrix.detach().cpu().numpy()
                image[AFFINE] = matrix @ image[AFFINE]

            floating_itk = image.as_sitk(force_3d=True)

            # Resample
            if fixed_reference_sitk is None:  # target is a spacing
                reference_image_sitk = self.get_reference_image(
                    floating_itk,
                    self.target_spacing,
                )
            else:
                reference_image_sitk = fixed_reference_sitk

            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(interpolator)
            resampler.SetReferenceImage(reference_image_sitk)
            resampled = resampler.Execute(floating_itk)

            image[DATA], image[AFFINE] = sitk_to_nib(resampled)
        return sample

    @staticmethod
    def get_reference_image(
            image: sitk.Image,
            spacing: TypeTripletFloat,
            ) -> sitk.Image:
        """Build an empty grid covering ``image`` with the given ``spacing``.

        The grid keeps the direction of ``image``. Its size is rounded up so
        the original field of view is fully covered. Its origin is placed so
        that the outer edge of the first new voxel coincides with the outer
        edge of the first original voxel.
        """
        old_spacing = np.array(image.GetSpacing())
        new_spacing = np.array(spacing)
        old_size = np.array(image.GetSize())
        new_size = old_size * old_spacing / new_spacing
        # int64 rather than uint16: large images or fine target spacings can
        # exceed 65535 voxels along an axis, which would silently wrap around
        new_size = np.ceil(new_size).astype(np.int64)
        # Continuous index (in the old grid) of the centre of the first new
        # voxel: old voxel 0 starts at index -0.5 and one new voxel spans
        # new_spacing / old_spacing old voxels
        new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
        new_origin_lps = image.TransformContinuousIndexToPhysicalPoint(
            new_origin_index.tolist())
        reference = sitk.Image(*new_size.tolist(), sitk.sitkFloat32)
        reference.SetDirection(image.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015.

        Returns ``spacing * sqrt((k**2 - 1) / (2 * sqrt(2 * ln 2))**2)`` with
        ``k = downsampling_factor``. For ``k < 1`` the variance is negative
        and ``np.sqrt`` yields NaN.
        """
        k = downsampling_factor
        # 2 * sqrt(2 * ln 2) is the FWHM-to-sigma ratio of a Gaussian
        variance = (k ** 2 - 1 ** 2) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
234