Passed
Pull Request — master (#133)
by Fernando
01:29
created

Resample.check_affine_key_presence()   A

Complexity

Conditions 3

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 2
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
from numbers import Number
2
from typing import Union, Tuple, Optional
3
4
import torch
0 ignored issues
show
introduced by
Unable to import 'torch'
Loading history...
5
import numpy as np
0 ignored issues
show
introduced by
Unable to import 'numpy'
Loading history...
6
import nibabel as nib
0 ignored issues
show
introduced by
Unable to import 'nibabel'
Loading history...
7
from nibabel.processing import resample_to_output, resample_from_to
0 ignored issues
show
introduced by
Unable to import 'nibabel.processing'
Loading history...
8
9
from ....data.subject import Subject
10
from ....torchio import LABEL, DATA, AFFINE, TYPE
11
from ... import Interpolation
12
from ... import Transform
13
14
15
# Accepted forms of a target spacing: a single number (used for all three
# axes by ``Resample.parse_spacing``) or an explicit tuple of three values.
TypeSpacing = Union[float, Tuple[float, float, float]]
16
17
18
class Resample(Transform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_d, s_h, s_w)`. If only one value
            :math:`n` is specified, then :math:`s_d = s_h = s_w = n`.
            If a string is given, all images will be resampled using the image
            with that name as reference.
        image_interpolation: Member of :py:class:`torchio.Interpolation`.
            Supported interpolation techniques for resampling are
            :py:attr:`torchio.Interpolation.NEAREST`,
            :py:attr:`torchio.Interpolation.LINEAR` and
            :py:attr:`torchio.Interpolation.BSPLINE`.
            Images whose type is ``LABEL`` are always resampled with
            nearest-neighbor interpolation, regardless of this argument.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, no additional matrix is applied.
            See usage in the example below.
        p: Probability that this transform will be applied.

    .. note:: Resampling is performed using
        :py:meth:`nibabel.processing.resample_to_output` or
        :py:meth:`nibabel.processing.resample_from_to`, depending on whether
        the target is a spacing or a reference image.

    Example:
        >>> import torchio
        >>> from torchio.transforms import Resample
        >>> transform = Resample(1)          # resample all images to 1mm iso
        >>> transform = Resample((1, 1, 1))  # resample all images to 1mm iso
        >>> transform = Resample('t1')       # resample all images to 't1' image space
        >>>
        >>> # Affine matrices are added to each image
        >>> matrix_to_mni = some_4_by_4_array  # e.g. result of registration to MNI space
        >>> subject = torchio.Subject(
        ...     t1=Image('t1.nii.gz', torchio.INTENSITY, to_mni=matrix_to_mni),
        ...     mni=Image('mni_152_lin.nii.gz', torchio.INTENSITY),
        ... )
        >>> resample = Resample(
        ...     'mni',  # this is subject key
        ...     pre_affine_name='to_mni',  # this is an image key
        ... )
        >>> dataset = torchio.ImagesDataset([subject], transform=resample)
        >>> sample = dataset[0]  # sample['t1'] is now in MNI space
    """
    def __init__(
            self,
            target: Union[TypeSpacing, str],
            image_interpolation: Interpolation = Interpolation.LINEAR,
            pre_affine_name: Optional[str] = None,
            p: float = 1,
            ):
        super().__init__(p=p)
        # Exactly one of these two is not None after parsing
        self.reference_image, self.target_spacing = self.parse_target(target)
        self.interpolation_order = self.parse_interpolation(
            image_interpolation)
        self.affine_name = pre_affine_name

    def parse_target(self, target: Union[TypeSpacing, str]):
        """Split ``target`` into ``(reference_image, target_spacing)``.

        A string is interpreted as the name of a reference image, and the
        spacing is ``None``. Anything else is parsed as a spacing with
        :py:meth:`parse_spacing`, and the reference name is ``None``.
        """
        if isinstance(target, str):
            return target, None
        return None, self.parse_spacing(target)

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Normalize a spacing into a tuple of three positive values.

        Args:
            spacing: A number, repeated for the three axes, or a tuple of
                length 3, returned as is.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a tuple of
                length 3, or if any value is not strictly positive.
        """
        if isinstance(spacing, tuple) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        # Validate the normalized tuple, which is what is actually returned
        if np.any(np.array(result) <= 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    @staticmethod
    def parse_interpolation(interpolation: Interpolation) -> int:
        """Map an :py:class:`Interpolation` member to a spline order.

        Returns:
            0 for ``NEAREST``, 1 for ``LINEAR`` and 3 for ``BSPLINE``.

        Raises:
            NotImplementedError: For any other value.
        """
        supported = (
            (Interpolation.NEAREST, 0),
            (Interpolation.LINEAR, 1),
            (Interpolation.BSPLINE, 3),
        )
        for member, order in supported:
            if interpolation == member:
                return order
        message = f'Interpolation not implemented yet: {interpolation}'
        raise NotImplementedError(message)

    @staticmethod
    def check_affine(affine_name: str, image_dict: dict):
        """Validate the matrix stored under ``affine_name``, if present.

        Nothing is checked about the matrix when ``affine_name`` is not a key
        of ``image_dict``.

        Raises:
            TypeError: If ``affine_name`` is not a string or the stored
                matrix is not a NumPy array.
            ValueError: If the stored matrix shape is not ``(4, 4)``.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name not in image_dict:
            return
        matrix = image_dict[affine_name]
        if not isinstance(matrix, np.ndarray):
            message = (
                'The affine matrix must be a NumPy array,'
                f' not {type(matrix)}'
            )
            raise TypeError(message)
        if matrix.shape != (4, 4):
            message = (
                'The affine matrix shape must be (4, 4),'
                f' not {matrix.shape}'
            )
            raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, sample: Subject):
        """Ensure that at least one image in ``sample`` has ``affine_name``.

        Raises:
            ValueError: If no image returned by
                ``sample.get_images(intensity_only=False)`` contains the key.
        """
        images = sample.get_images(intensity_only=False)
        if any(affine_name in image_dict for image_dict in images):
            return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the sample'
        )
        raise ValueError(message)

    def apply_transform(self, sample: Subject) -> dict:
        """Resample every image in ``sample`` in place and return the sample.

        The reference image, if one is used, is left untouched. For each
        other image, the optional pre-affine is composed with the image
        affine and the data and affine are then replaced by the resampled
        versions.

        Raises:
            ValueError: If a pre-affine name was given but no image contains
                it, or if the reference image name is not found in
                ``sample``.
        """
        use_reference = self.reference_image is not None
        use_pre_affine = self.affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.affine_name, sample)

        # Resolve the resampling target once, before any image is modified,
        # so that a missing reference does not leave the sample half-updated.
        # The reference image is skipped in the loop below, so its data and
        # affine are the same objects for every iteration.
        if use_reference:
            try:
                ref_image_dict = sample[self.reference_image]
            except KeyError as error:
                message = (
                    f'Reference name "{self.reference_image}"'
                    ' not found in sample'
                )
                raise ValueError(message) from error
            reference = ref_image_dict[DATA], ref_image_dict[AFFINE]
            kwargs = dict(reference=reference)
        else:
            kwargs = dict(target_spacing=self.target_spacing)

        images_dict = sample.get_images_dict(intensity_only=False).items()
        for image_name, image_dict in images_dict:
            # Do not resample the reference image if there is one
            if use_reference and image_name == self.reference_image:
                continue

            # Choose interpolator
            if image_dict[TYPE] == LABEL:
                interpolation_order = 0  # nearest neighbor
            else:
                interpolation_order = self.interpolation_order

            # Apply given affine matrix if found in image
            if use_pre_affine and self.affine_name in image_dict:
                self.check_affine(self.affine_name, image_dict)
                matrix = image_dict[self.affine_name]
                image_dict[AFFINE] = matrix @ image_dict[AFFINE]

            # Resample
            image_dict[DATA], image_dict[AFFINE] = self.apply_resample(
                image_dict[DATA],
                image_dict[AFFINE],
                interpolation_order,
                **kwargs,
            )
        return sample

    @staticmethod
    def apply_resample(
            tensor: torch.Tensor,
            affine: np.ndarray,
            interpolation_order: int,
            target_spacing: Optional[Tuple[float, float, float]] = None,
            reference: Optional[Tuple[torch.Tensor, np.ndarray]] = None,
            ) -> Tuple[torch.Tensor, np.ndarray]:
        """Resample one image to a target spacing or onto a reference image.

        Only the first element along the first dimension of ``tensor`` (and
        of the reference tensor) is used; the result gets a new leading
        dimension of size 1.

        Args:
            tensor: Image data.
            affine: Affine matrix associated with ``tensor``.
            interpolation_order: Spline order passed to nibabel as ``order``.
            target_spacing: Voxel sizes passed to
                :py:func:`nibabel.processing.resample_to_output`. Used only
                when ``reference`` is ``None``.
            reference: Tuple ``(tensor, affine)`` of the image to resample
                onto with :py:func:`nibabel.processing.resample_from_to`.

        Returns:
            Tuple of the resampled data, as a ``float32`` tensor, and the
            affine of the resampled image.
        """
        array = tensor.numpy()[0]
        source_nii = nib.Nifti1Image(array, affine)
        if reference is None:
            nii = resample_to_output(
                source_nii,
                voxel_sizes=target_spacing,
                order=interpolation_order,
            )
        else:
            reference_tensor, reference_affine = reference
            reference_array = reference_tensor.numpy()[0]
            nii = resample_from_to(
                source_nii,
                nib.Nifti1Image(reference_array, reference_affine),
                order=interpolation_order,
            )
        tensor = torch.from_numpy(nii.get_fdata(dtype=np.float32))
        tensor = tensor.unsqueeze(dim=0)
        return tensor, nii.affine
218