Passed
Pull Request — master (#246)
by Fernando
01:33
created

torchio.transforms.preprocessing.spatial.resample   A

Complexity

Total Complexity 38

Size/Duplication

Total Lines 249
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 157
dl 0
loc 249
rs 9.36
c 0
b 0
f 0
wmc 38

9 Methods

Rating   Name   Duplication   Size   Complexity  
A Resample.check_affine_key_presence() 0 10 3
A Resample.check_affine() 0 22 5
A Resample.parse_spacing() 0 15 5
D Resample.apply_transform() 0 48 12
A Resample.__init__() 0 12 1
A Resample.parse_target() 0 16 3
A Resample.parse_interpolation() 0 13 4
A Resample.get_sigma() 0 11 1
A Resample.apply_resample() 0 31 4
1
from pathlib import Path
2
from numbers import Number
3
from typing import Union, Tuple, Optional
4
5
import torch
6
import numpy as np
7
import nibabel as nib
8
from nibabel.processing import resample_to_output, resample_from_to
9
10
from ....data.subject import Subject
11
from ....data.image import Image
12
from ....torchio import DATA, AFFINE, TYPE, INTENSITY, TypeData
13
from ... import Transform, Interpolation
14
15
16
17
# Voxel spacing: a single value used for all three axes, or one value per axis.
TypeSpacing = Union[float, Tuple[float, float, float]]
# Result of ``Resample.parse_target``: ``(reference, target_spacing)`` where
# exactly one element is not ``None``. ``reference`` is either the name of an
# image in the sample (``str``) or the ``(data, affine)`` pair of an image
# read from disk.
TypeTarget = Tuple[
    Optional[Union[str, Tuple[torch.Tensor, np.ndarray]]],
    Optional[Tuple[float, float, float]],
]
22
23
24
class Resample(Transform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_d, s_h, s_w)`. If only one value
            :math:`n` is specified, then :math:`s_d = s_h = s_w = n`.
            A list of three values is also accepted.
            If a string or :py:class:`~pathlib.Path` pointing to an existing
            file is given, all images will be resampled using the image
            found at that path as reference. Otherwise, the string is
            interpreted as the name of the image in the sample that will be
            used as reference.
        image_interpolation: String that defines the interpolation technique.
            Supported interpolation techniques for resampling
            are ``'nearest'``, ``'linear'`` and ``'bspline'``.
            Using a member of :py:class:`torchio.Interpolation` is still
            supported for backward compatibility,
            but will be removed in a future version.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        p: Probability that this transform will be applied.
        copy: Forwarded unchanged to the base :py:class:`Transform`.

    .. note:: Resampling is performed using
        :py:meth:`nibabel.processing.resample_to_output` or
        :py:meth:`nibabel.processing.resample_from_to`, depending on whether
        the target is a spacing or a reference image.

    Example:
        >>> import torchio
        >>> from torchio import Resample
        >>> from torchio.datasets import Colin27, FPG
        >>> transform = Resample(1)                     # resample all images to 1mm iso
        >>> transform = Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = Resample('t1')                  # resample all images to 't1' image space
        >>> colin = Colin27()  # these images are in the MNI space
        >>> fpg = FPG()  # matrices to the MNI space are included here
        >>> # Resample all images into the MNI space
        >>> transform = Resample(colin.t1.path, pre_affine_name='affine_matrix')
        >>> transformed = transform(fpg)  # images in fpg are now in MNI space
    """
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path],
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            p: float = 1,
            copy: bool = True,
            ):
        super().__init__(p=p, copy=copy)
        self.reference_image, self.target_spacing = self.parse_target(target)
        self.interpolation_order = self.parse_interpolation(image_interpolation)
        self.affine_name = pre_affine_name

    def parse_target(
            self,
            target: Union[TypeSpacing, str, Path],
            ) -> TypeTarget:
        """Split ``target`` into a reference and a target spacing.

        Returns:
            Tuple ``(reference_image, target_spacing)`` where exactly one
            element is not ``None``. ``reference_image`` is either the
            ``(data, affine)`` pair of an image read from disk or the name
            (``str``) of an image in the sample.
        """
        if not isinstance(target, (str, Path)):
            return None, self.parse_spacing(target)
        if Path(target).is_file():
            image = Image(target)
            return (image.data, image.affine), None
        # Not a file on disk: use it as the name of an image in the sample.
        # Normalizing to str matters for Path objects, which would otherwise
        # be mistaken for a (data, affine) pair in apply_transform.
        return str(target), None

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Return a 3-tuple of positive spacings.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a 3-element
                tuple/list, or if any value is not positive.
        """
        if isinstance(spacing, (tuple, list)) and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    def parse_interpolation(self, interpolation: str) -> int:
        """Map an interpolation name/enum to a spline order (0, 1 or 3).

        Raises:
            NotImplementedError: For unsupported interpolation techniques.
        """
        interpolation = super().parse_interpolation(interpolation)

        if interpolation in (Interpolation.NEAREST, 'nearest'):
            order = 0
        elif interpolation in (Interpolation.LINEAR, 'linear'):
            order = 1
        elif interpolation in (Interpolation.BSPLINE, 'bspline'):
            order = 3
        else:
            message = f'Interpolation not implemented yet: {interpolation}'
            raise NotImplementedError(message)
        return order

    @staticmethod
    def check_affine(affine_name: str, image_dict: dict):
        """Validate the pre-affine matrix stored under ``affine_name``.

        Raises:
            TypeError: If ``affine_name`` is not a string or the stored matrix
                is not a NumPy array / PyTorch tensor.
            ValueError: If the stored matrix is not 4x4.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image_dict:
            matrix = image_dict[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch tensor,'
                    f' not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, sample: Subject):
        """Raise ``ValueError`` unless at least one image has ``affine_name``."""
        for image_dict in sample.get_images(intensity_only=False):
            if affine_name in image_dict:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the sample'
        )
        raise ValueError(message)

    def _get_reference(self, sample: Subject) -> Tuple[torch.Tensor, np.ndarray]:
        """Return the ``(data, affine)`` pair of the reference image.

        Raises:
            ValueError: If the reference is given by name and that name is
                not present in ``sample``.
        """
        if not isinstance(self.reference_image, str):
            return self.reference_image
        try:
            ref_image_dict = sample[self.reference_image]
        except KeyError as error:
            message = (
                f'Reference name "{self.reference_image}"'
                ' not found in sample'
            )
            raise ValueError(message) from error
        return ref_image_dict[DATA], ref_image_dict[AFFINE]

    def apply_transform(self, sample: Subject) -> Subject:
        """Resample every image in ``sample`` in place and return ``sample``."""
        use_reference = self.reference_image is not None
        use_pre_affine = self.affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.affine_name, sample)
        # Resolve the reference before any image is modified, so a missing
        # reference name cannot leave the sample partially transformed.
        # The reference image itself is skipped in the loop, so its data
        # does not change while iterating.
        reference = self._get_reference(sample) if use_reference else None
        images_dict = sample.get_images_dict(intensity_only=False).items()
        for image_name, image_dict in images_dict:
            # Do not resample the reference image if there is one
            if use_reference and image_name == self.reference_image:
                continue

            # Non-intensity images always use nearest neighbor (order 0),
            # presumably so that discrete values such as labels are preserved
            if image_dict[TYPE] != INTENSITY:
                interpolation_order = 0
            else:
                interpolation_order = self.interpolation_order

            # Apply given affine matrix if found in image
            if use_pre_affine and self.affine_name in image_dict:
                self.check_affine(self.affine_name, image_dict)
                matrix = image_dict[self.affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image_dict[AFFINE] = matrix @ image_dict[AFFINE]

            if use_reference:
                kwargs = dict(reference=reference)
            else:
                kwargs = dict(target_spacing=self.target_spacing)
            image_dict[DATA], image_dict[AFFINE] = self.apply_resample(
                image_dict[DATA],
                image_dict[AFFINE],
                interpolation_order,
                **kwargs,
            )
        return sample

    @staticmethod
    def apply_resample(
            tensor: torch.Tensor,  # (C, D, H, W)
            affine: np.ndarray,
            interpolation_order: int,
            target_spacing: Optional[Tuple[float, float, float]] = None,
            reference: Optional[Tuple[torch.Tensor, np.ndarray]] = None,
            ) -> Tuple[torch.Tensor, np.ndarray]:
        """Resample each channel of ``tensor`` independently.

        If ``reference`` is ``None``, channels are resampled to
        ``target_spacing`` with ``resample_to_output``. Otherwise they are
        resampled into the grid of the first channel of the reference with
        ``resample_from_to``.

        Returns:
            The resampled float32 tensor (channels stacked on the first axis)
            and the affine of the resampled images.

        Raises:
            ValueError: If ``tensor`` has no channels.
        """
        array = tensor.numpy()
        if array.shape[0] == 0:
            message = (
                'Cannot resample a tensor with no channels'
                f' (shape {tuple(tensor.shape)})'
            )
            raise ValueError(message)

        if reference is None:
            def resample(nii):
                return resample_to_output(
                    nii,
                    voxel_sizes=target_spacing,
                    order=interpolation_order,
                )
        else:
            reference_tensor, reference_affine = reference
            # Built once and reused for every channel
            reference_nii = nib.Nifti1Image(
                reference_tensor.numpy()[0],
                reference_affine,
            )

            def resample(nii):
                return resample_from_to(
                    nii,
                    reference_nii,
                    order=interpolation_order,
                )

        arrays_resampled = []
        for channel in array:
            nii = resample(nib.Nifti1Image(channel, affine))
            arrays_resampled.append(nii.get_fdata(dtype=np.float32))
        # np.stack + from_numpy avoids the slow element-wise conversion that
        # torch.Tensor(list_of_arrays) performs; the dtype stays float32
        tensor = torch.from_numpy(np.stack(arrays_resampled))
        return tensor, nii.affine

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015
        """
        k = downsampling_factor
        # 2 * sqrt(2 * ln 2) is the ratio between a Gaussian's FWHM and sigma
        variance = (k ** 2 - 1 ** 2) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
249