Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

Resample.apply_transform()   C

Complexity

Conditions 11

Size

Total Lines 46
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 11
eloc 33
nop 2
dl 0
loc 46
rs 5.4
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex methods like torchio.transforms.preprocessing.spatial.resample.Resample.apply_transform() often do several different things at once. To break such a method down, identify a cohesive piece of logic within it. A common approach is to look for statements that work on the same variables or serve a single purpose (in a class, look for fields and methods that share the same prefixes or suffixes).

Once you have determined which pieces belong together, you can apply the Extract Method refactoring (or, for a class, the Extract Class refactoring). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.

1
from numbers import Number
2
from typing import Union, Tuple, Optional
3
from pathlib import Path
4
import warnings
5
6
import torch
7
import numpy as np
8
import nibabel as nib
9
from nibabel.processing import resample_to_output, resample_from_to
10
11
from ....data.subject import Subject
12
from ....data.image import Image
13
from ....torchio import LABEL, DATA, AFFINE, TYPE, INTENSITY
14
from ... import Interpolation
15
from ... import Transform
16
17
18
# One voxel size per spatial axis
_TypeTriplet = Tuple[float, float, float]

# Accepted spacing: a single isotropic value or one value per axis
TypeSpacing = Union[float, _TypeTriplet]

# Pair returned by Resample.parse_target: (reference, spacing), where
# exactly one of the two elements is None
TypeTarget = Tuple[Optional[Union[Image, str]], Optional[_TypeTriplet]]
23
24
25
class Resample(Transform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_d, s_h, s_w)` (a list of three values is
            also accepted). If only one value :math:`n` is specified, then
            :math:`s_d = s_h = s_w = n`.
            If a string or :py:class:`~pathlib.Path` pointing to an existing
            file is given, all images will be resampled to the space of the
            image found at that path. Any other string or path is used as the
            name of the image in the subject that serves as reference.
        image_interpolation: String that defines the interpolation technique.
            Supported interpolation techniques for resampling
            are ``'nearest'``, ``'linear'`` and ``'bspline'``.
            Using a member of :py:class:`torchio.Interpolation` is still
            supported for backward compatibility,
            but will be removed in a future version.
            Label maps (images whose type is ``LABEL``) are always resampled
            with nearest-neighbor interpolation, regardless of this value.
        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        p: Probability that this transform will be applied.

    .. note:: Resampling is performed using
        :py:meth:`nibabel.processing.resample_to_output` or
        :py:meth:`nibabel.processing.resample_from_to`, depending on whether
        the target is a spacing or a reference image.

    .. note:: When the reference is an image of the subject, that image is
        left untouched.

    Example:
        >>> import torchio
        >>> from torchio.transforms import Resample
        >>> transform = Resample(1)                     # resample all images to 1mm iso
        >>> transform = Resample((1, 1, 1))             # resample all images to 1mm iso
        >>> transform = Resample('t1')                  # resample all images to 't1' image space
        >>> transform = Resample('path/to/ref.nii.gz')  # resample all images to space of image at this path
        >>>
        >>> # Affine matrices are added to each image
        >>> matrix_to_mni = some_4_by_4_array  # e.g. result of registration to MNI space
        >>> subject = torchio.Subject(
        ...     t1=torchio.Image('t1.nii.gz', torchio.INTENSITY, to_mni=matrix_to_mni),
        ...     mni=torchio.Image('mni_152_lin.nii.gz', torchio.INTENSITY),
        ... )
        >>> resample = Resample(
        ...     'mni',  # this is a subject key
        ...     pre_affine_name='to_mni',  # this is an image key
        ... )
        >>> dataset = torchio.ImagesDataset([subject], transform=resample)
        >>> sample = dataset[0]  # sample['t1'] is now in MNI space
    """
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path],
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            p: float = 1,
            ):
        super().__init__(p=p)
        self.reference_image, self.target_spacing = self.parse_target(target)
        self.interpolation_order = self.parse_interpolation(image_interpolation)
        self.affine_name = pre_affine_name

    def parse_target(
            self,
            target: Union[TypeSpacing, str, Path],
            ) -> TypeTarget:
        """Split ``target`` into a reference and a target spacing.

        Returns:
            A pair ``(reference_image, target_spacing)`` in which exactly one
            element is ``None``. If ``target`` is a string or path to an
            existing file, ``reference_image`` is the value returned by
            ``Image(target, INTENSITY).load()``; it is later unpacked as a
            ``(data, affine)`` pair by :py:meth:`apply_resample`. Any other
            string or path is returned as a ``str`` image name.

        Raises:
            ValueError: If ``target`` is not a valid spacing.
        """
        if isinstance(target, (str, Path)):
            if Path(target).is_file():
                reference_image = Image(target, INTENSITY).load()
            else:
                # Used as an image name. Converting a Path to str lets the
                # name comparison and lookup in apply_transform work, so an
                # unknown name raises a clear ValueError there.
                reference_image = str(target)
            target_spacing = None
        else:
            reference_image = None
            target_spacing = self.parse_spacing(target)
        return reference_image, target_spacing

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Return ``spacing`` as a 3-tuple of positive values.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a tuple/list
                of three values, or if any value is not positive.
        """
        if isinstance(spacing, (tuple, list)) and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    def parse_interpolation(self, interpolation: str) -> int:
        """Map an interpolation name to the spline order used by NiBabel.

        Returns:
            0 for nearest neighbor, 1 for linear, 3 for B-spline.

        Raises:
            NotImplementedError: For any other interpolation.
        """
        interpolation = super().parse_interpolation(interpolation)

        if interpolation in (Interpolation.NEAREST, 'nearest'):
            order = 0
        elif interpolation in (Interpolation.LINEAR, 'linear'):
            order = 1
        elif interpolation in (Interpolation.BSPLINE, 'bspline'):
            order = 3
        else:
            message = f'Interpolation not implemented yet: {interpolation}'
            raise NotImplementedError(message)
        return order

    @staticmethod
    def check_affine(affine_name: str, image_dict: dict):
        """Validate the pre-affine matrix stored in ``image_dict``, if any.

        Raises:
            TypeError: If ``affine_name`` is not a string, or if the stored
                matrix is not a NumPy array.
            ValueError: If the stored matrix shape is not ``(4, 4)``.
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image_dict:
            matrix = image_dict[affine_name]
            if not isinstance(matrix, np.ndarray):
                message = (
                    'The affine matrix must be a NumPy array,'
                    f' not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, sample: Subject):
        """Raise ``ValueError`` if no image of ``sample`` has ``affine_name``."""
        for image_dict in sample.get_images(intensity_only=False):
            if affine_name in image_dict:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the sample'
        )
        raise ValueError(message)

    def apply_transform(self, sample: Subject) -> Subject:
        """Resample the images of ``sample`` in place and return ``sample``.

        For each image other than the reference image:
        1. The optional pre-affine is composed with the image affine.
        2. The image is resampled to the target spacing or onto the
           reference grid.

        Raises:
            ValueError: If a pre-affine name was given but no image has it,
                or if the reference name is not a key of ``sample``.
        """
        if self.affine_name is not None:
            self.check_affine_key_presence(self.affine_name, sample)
        # Resolved once, before the loop. The reference image is skipped
        # below, so its data and affine are the same for every image.
        reference = self._get_reference(sample)
        images = sample.get_images_dict(intensity_only=False)
        for image_name, image_dict in images.items():
            # Do not resample the reference image if there is one
            if reference is not None and image_name == self.reference_image:
                continue
            # Must run before DATA/AFFINE are read for resampling
            self._apply_pre_affine(image_dict)
            # target_spacing is ignored by apply_resample when a reference
            # is given (and parse_target sets it to None in that case)
            image_dict[DATA], image_dict[AFFINE] = self.apply_resample(
                image_dict[DATA],
                image_dict[AFFINE],
                self._get_interpolation_order(image_dict),
                target_spacing=self.target_spacing,
                reference=reference,
            )
        return sample

    def _get_reference(self, sample: Subject):
        """Return the ``(data, affine)`` pair to resample onto, or ``None``.

        Raises:
            ValueError: If the reference is an image name that is not a key
                of ``sample``.
        """
        if self.reference_image is None:
            return None
        if not isinstance(self.reference_image, str):
            # Already loaded from disk in parse_target()
            return self.reference_image
        try:
            ref_image_dict = sample[self.reference_image]
        except KeyError as error:
            message = (
                f'Reference name "{self.reference_image}"'
                ' not found in sample'
            )
            raise ValueError(message) from error
        return ref_image_dict[DATA], ref_image_dict[AFFINE]

    def _apply_pre_affine(self, image_dict: dict) -> None:
        """Left-multiply the image affine by its pre-affine matrix, if any."""
        if self.affine_name is None or self.affine_name not in image_dict:
            return
        self.check_affine(self.affine_name, image_dict)
        matrix = image_dict[self.affine_name]
        image_dict[AFFINE] = matrix @ image_dict[AFFINE]

    def _get_interpolation_order(self, image_dict: dict) -> int:
        """Return the spline order to use when resampling one image."""
        if image_dict[TYPE] == LABEL:
            return 0  # nearest neighbor for label maps
        return self.interpolation_order

    @staticmethod
    def apply_resample(
            tensor: torch.Tensor,
            affine: np.ndarray,
            interpolation_order: int,
            target_spacing: Optional[Tuple[float, float, float]] = None,
            reference: Optional[Tuple[torch.Tensor, np.ndarray]] = None,
            ) -> Tuple[torch.Tensor, np.ndarray]:
        """Resample one image with NiBabel.

        Only the first channel (``tensor[0]``) is resampled; the returned
        tensor has a single leading channel and dtype ``float32``.

        Args:
            tensor: Image data; channel 0 is resampled.
            affine: Affine matrix associated with ``tensor``.
            interpolation_order: Spline order passed to NiBabel
                (0: nearest, 1: linear, 3: B-spline).
            target_spacing: Output voxel sizes, used when ``reference`` is
                ``None``.
            reference: ``(data, affine)`` of the image whose grid the output
                must match. Channel 0 of ``data`` is used.

        Returns:
            The resampled tensor and its affine matrix.
        """
        source = nib.Nifti1Image(tensor.numpy()[0], affine)
        if reference is None:
            nii = resample_to_output(
                source,
                voxel_sizes=target_spacing,
                order=interpolation_order,
            )
        else:
            reference_tensor, reference_affine = reference
            reference_nii = nib.Nifti1Image(
                reference_tensor.numpy()[0],
                reference_affine,
            )
            nii = resample_from_to(
                source,
                reference_nii,
                order=interpolation_order,
            )
        tensor = torch.from_numpy(nii.get_fdata(dtype=np.float32))
        return tensor.unsqueeze(dim=0), nii.affine