Passed
Push — master ( 29643a...47015e )
by Fernando
01:22
created

Resample.get_sigma()   A

Complexity

Conditions 1

Size

Total Lines 11
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 2
dl 0
loc 11
rs 10
c 0
b 0
f 0
1
from pathlib import Path
2
from numbers import Number
3
from typing import Union, Tuple, Optional, List
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
from nibabel.processing import resample_to_output, resample_from_to
9
10
from ....data.subject import Subject
11
from ....data.image import Image, ScalarImage
12
from ....torchio import DATA, AFFINE, TYPE, INTENSITY, TypeData, TypeTripletFloat
13
from ....utils import sitk_to_nib
14
from ... import SpatialTransform
15
from ... import Interpolation, get_sitk_interpolator
16
17
18
# Voxel spacing: either one value shared by all three axes, or one per axis.
TypeSpacing = Union[float, Tuple[float, float, float]]
# Return type of Resample.parse_target: a pair of
# (reference image or reference image name, target spacing);
# either element may be None.
TypeTarget = Tuple[
    Union[Image, str, None],
    Union[Tuple[float, float, float], None],
]
23
24
25
class Resample(SpatialTransform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_h, s_w, s_d)`. If only one value
            :math:`n` is specified, then :math:`s_h = s_w = s_d = n`.
            A list of three values is also accepted.
            If a string or :py:class:`~pathlib.Path` is given,
            all images will be resampled using the image
            with that name as reference or found at the path.
            A string or path that is not an existing file is used as the
            name of an image in the subject.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        image_interpolation: String that defines the interpolation technique.
            Supported interpolation techniques for resampling
            are ``'nearest'``, ``'linear'`` and ``'bspline'``.
            Using a member of :py:class:`torchio.Interpolation` is still
            supported for backward compatibility,
            but will be removed in a future version.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio
        >>> from torchio import Resample
        >>> from torchio.datasets import Colin27, FPG
        >>> transform = Resample(1)                     # resample all images to 1mm iso
        >>> transform = Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = Resample('t1')                  # resample all images to 't1' image space
        >>> colin = Colin27()  # these images are in the MNI space
        >>> fpg = FPG()  # matrices to the MNI space are included here
        >>> # Resample all images into the MNI space
        >>> transform = Resample(colin.t1.path, pre_affine_name='affine_matrix')
        >>> transformed = transform(fpg)  # images in fpg are now in MNI space
    """
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path],
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        # Exactly one of these is not None (see parse_target).
        self.reference_image, self.target_spacing = self.parse_target(target)
        self.interpolation = self.parse_interpolation(image_interpolation)
        self.affine_name = pre_affine_name

    def parse_target(
            self,
            target: Union[TypeSpacing, str, Path],
            ) -> TypeTarget:
        """Split ``target`` into a reference image and a target spacing.

        Returns:
            A tuple ``(reference_image, target_spacing)``:

            - If ``target`` is a string or path pointing to an existing
              file, ``reference_image`` is a :class:`ScalarImage` loaded
              from it and ``target_spacing`` is ``None``.
            - If ``target`` is a string or path that is not an existing
              file, ``reference_image`` is its string form (later looked up
              as an image name in the subject) and ``target_spacing`` is
              ``None``.
            - Otherwise ``reference_image`` is ``None`` and
              ``target_spacing`` is the result of :meth:`parse_spacing`.
        """
        if isinstance(target, (str, Path)):
            if Path(target).is_file():
                reference_image = ScalarImage(target)
            else:
                # Always store a str: a Path object here would match none of
                # the branches in apply_transform and leave the reference
                # undefined.
                reference_image = str(target)
            target_spacing = None
        else:
            reference_image = None
            target_spacing = self.parse_spacing(target)
        return reference_image, target_spacing

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Normalize ``spacing`` to a tuple of three positive values.

        Args:
            spacing: A single number (used for all three axes) or a tuple
                or list of exactly three numbers.

        Raises:
            ValueError: If ``spacing`` has an unsupported type or length, or
                if any value is not strictly positive (NaN included).
        """
        if isinstance(spacing, (tuple, list)) and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        # "not all > 0" rather than "any <= 0" so that NaN is rejected too.
        if not np.all(np.array(result) > 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    @staticmethod
    def check_affine(affine_name: str, image_dict: dict):
        """Validate the pre-affine matrix stored under ``affine_name``.

        Does nothing if ``image_dict`` has no ``affine_name`` key.

        Raises:
            TypeError: If ``affine_name`` is not a string, or the stored
                matrix is neither a NumPy array nor a PyTorch tensor.
            ValueError: If the stored matrix is not 4 x 4.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image_dict:
            matrix = image_dict[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch tensor,'
                    f' not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, sample: Subject):
        """Raise ``ValueError`` unless at least one image has ``affine_name``."""
        for image_dict in sample.get_images(intensity_only=False):
            if affine_name in image_dict:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the sample'
        )
        raise ValueError(message)

    def apply_transform(self, sample: Subject) -> Subject:
        """Resample every selected image of ``sample`` in place.

        Non-intensity images always use nearest-neighbor interpolation.
        If a pre-affine name was given, the matrix stored under that key is
        left-multiplied onto the image affine before resampling.

        Returns:
            The same ``sample`` object, with image data and affines replaced.

        Raises:
            ValueError: If the reference image name is not in ``sample``, or
                the pre-affine key is not present in any image.
            TypeError: If ``self.reference_image`` has an unsupported type.
        """
        use_pre_affine = self.affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.affine_name, sample)
        for _, image in self.get_images_dict(sample).items():
            # Do not resample the reference image if there is one.
            # NOTE(review): self.reference_image is a str, a ScalarImage
            # created in parse_target, or None, so this identity test does
            # not match a named reference image inside the subject; that
            # image is resampled onto its own grid.
            if image is self.reference_image:
                continue

            # Choose interpolation: labels/masks must not be blended.
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation
            interpolator = get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.affine_name in image:
                self.check_affine(self.affine_name, image)
                matrix = image[self.affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image[AFFINE] = matrix @ image[AFFINE]

            floating_itk = image.as_sitk(force_3d=True)

            # Resample
            if isinstance(self.reference_image, str):
                try:
                    reference_image_sitk = sample[self.reference_image].as_sitk()
                except KeyError as error:
                    message = (
                        f'Reference name "{self.reference_image}"'
                        ' not found in sample'
                    )
                    raise ValueError(message) from error
            elif isinstance(self.reference_image, ScalarImage):
                reference_image_sitk = self.reference_image.as_sitk()
            elif self.reference_image is None:  # target is a spacing
                reference_image_sitk = self.get_reference_image(
                    floating_itk,
                    self.target_spacing,
                )
            else:
                # Guards against an unsupported reference type, which would
                # otherwise leave reference_image_sitk undefined below.
                message = (
                    'Reference image must be an image name, a ScalarImage'
                    f' or None, not {type(self.reference_image)}'
                )
                raise TypeError(message)

            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(interpolator)
            resampler.SetReferenceImage(reference_image_sitk)
            resampled = resampler.Execute(floating_itk)

            image[DATA], image[AFFINE] = sitk_to_nib(resampled)
        return sample

    @staticmethod
    def get_reference_image(
            image: sitk.Image,
            spacing: TypeTripletFloat,
            ) -> sitk.Image:
        """Build an empty grid covering ``image`` with the given ``spacing``.

        The new size is rounded up so the whole field of view is covered.
        The direction matches ``image``.
        """
        old_spacing = np.array(image.GetSpacing())
        new_spacing = np.array(spacing)
        old_size = np.array(image.GetSize())
        new_size = old_size * old_spacing / new_spacing
        # Cast to a wide integer type: uint16 would silently wrap for grids
        # larger than 65535 voxels along an axis (e.g. strong upsampling).
        new_size = np.ceil(new_size).astype(np.int64)
        # In old continuous-index units the grid corner is at -0.5; placing
        # the first new voxel center half a new voxel inside that corner
        # gives 0.5 * (new / old - 1), so both grids share the same corner.
        new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
        new_origin_lps = image.TransformContinuousIndexToPhysicalPoint(
            new_origin_index.tolist())
        reference = sitk.Image(*new_size.tolist(), sitk.sitkFloat32)
        reference.SetDirection(image.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015

        The variance is ``(k**2 - 1) / (2 * sqrt(2 * ln 2))**2``, where
        ``2 * sqrt(2 * ln 2)`` is the FWHM-to-sigma conversion factor.
        The result is scaled by ``spacing``. For ``k < 1`` the variance is
        negative and the result is NaN.
        """
        k = downsampling_factor
        variance = (k ** 2 - 1 ** 2) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma
227