Passed
Pull Request: master (#133) by Fernando, created 01:29

Resample.apply_transform()   Rating: C

Complexity
    Conditions: 10

Size
    Total lines: 44
    Code lines: 32

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       10
eloc     32
nop      2
dl       0
loc      44
rs       5.9999
c        0
b        0
f        0

How to fix: Complexity

Complex methods like torchio.transforms.preprocessing.spatial.resample.Resample.apply_transform() often do a lot of different things. To break such a method down, identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the parts that belong together, you can move them into a helper method (the Extract Method refactoring) or, if the component also carries its own state, into a separate class (Extract Class). If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
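For apply_transform(), the per-image decisions made inside the loop (which interpolation order to use and which coregistration matrix to apply to the affine) are natural candidates for extraction. The sketch below is a minimal, standalone illustration of that idea, not torchio code: it assumes plain dicts in place of a Subject, uses hypothetical string constants in place of torchio's TYPE, LABEL and AFFINE, and leaves out the validation that check_coregistration() performs. The helper names are hypothetical.

```python
import numpy as np

# Hypothetical stand-ins for the torchio constants TYPE, LABEL and AFFINE
TYPE, LABEL, AFFINE = 'type', 'label', 'affine'


def choose_interpolation_order(image_dict: dict, default_order: int) -> int:
    """Label maps are always resampled with nearest neighbor (order 0)."""
    return 0 if image_dict[TYPE] == LABEL else default_order


def get_coregistration_matrix(image_dict: dict, key) -> np.ndarray:
    """Return the image's coregistration matrix, or the identity if it has none."""
    if key is None or key not in image_dict:
        return np.eye(4)
    return image_dict[key]


def get_resampling_affine(image_dict: dict, key) -> np.ndarray:
    """Affine handed to the resampler: coregistration matrix times image affine."""
    return np.dot(get_coregistration_matrix(image_dict, key), image_dict[AFFINE])


# With these helpers, the loop body of apply_transform() could shrink to:
#     order = choose_interpolation_order(image_dict, self.interpolation_order)
#     affine = get_resampling_affine(image_dict, self.coregistration)
label = {TYPE: LABEL, AFFINE: np.diag([2.0, 2.0, 2.0, 1.0])}
shift = np.eye(4)
shift[0, 3] = 5.0  # hypothetical 5 mm translation along the first axis
intensity = {TYPE: 'intensity', AFFINE: np.eye(4), 'coreg': shift}

print(choose_interpolation_order(label, 3))              # 0
print(choose_interpolation_order(intensity, 3))          # 3
print(get_resampling_affine(intensity, 'coreg')[0, 3])   # 5.0
```

Each helper has a single condition, so the branching that drives the method's complexity score is spread over small functions that can be tested on their own.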

from numbers import Number
from typing import Iterable, Optional, Tuple, Union

# Issues: Unable to import 'torch', 'numpy', 'nibabel' and 'nibabel.processing'
import torch
import numpy as np
import nibabel as nib
from nibabel.processing import resample_to_output, resample_from_to

from ....data.subject import Subject
from ....torchio import LABEL, DATA, AFFINE, TYPE
from ... import Interpolation
from ... import Transform


TypeSpacing = Union[float, Tuple[float, float, float]]


class Resample(Transform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_d, s_h, s_w)` with the output voxel spacing.
            If only one value :math:`n` is specified, then
            :math:`s_d = s_h = s_w = n`.
            If a string is given, all images will be resampled using the image
            with that name as reference.
        antialiasing: (Not implemented yet).
        image_interpolation: Member of :py:class:`torchio.Interpolation`.
            Supported interpolation techniques for resampling are
            :py:attr:`torchio.Interpolation.NEAREST`,
            :py:attr:`torchio.Interpolation.LINEAR` and
            :py:attr:`torchio.Interpolation.BSPLINE`.
            Label maps are always resampled with nearest-neighbor
            interpolation.
        p: Probability that this transform will be applied.
        coregistration: Name of the image attribute that stores a
            coregistration matrix, a :py:class:`numpy.ndarray` of shape
            ``(4, 4)``. If not ``None``, the affine of each image that has
            this attribute is left-multiplied by its matrix before
            resampling. The attribute must be present in at least one image.

    .. note:: Resampling is performed using
        :py:meth:`nibabel.processing.resample_to_output` or
        :py:meth:`nibabel.processing.resample_from_to`, depending on whether
        the target is a spacing or a reference image.

    Example:
        >>> from torchio.transforms import Resample
        >>> transform = Resample(1)          # resample all images to 1mm iso
        >>> transform = Resample((1, 1, 1))  # resample all images to 1mm iso
        >>> transform = Resample('t1')       # resample all images to 't1' image space

    """
    def __init__(  # Issue: Too many arguments (6/5)
            self,
            target: Union[TypeSpacing, str],
            antialiasing: bool = True,
            image_interpolation: Interpolation = Interpolation.LINEAR,
            p: float = 1,
            coregistration: Optional[str] = None,
            ):
        super().__init__(p=p)
        self.target_spacing: Optional[Tuple[float, float, float]]
        self.reference_image: Optional[str]
        self.parse_target(target)
        self.antialiasing = antialiasing
        self.interpolation_order = self.parse_interpolation(
            image_interpolation)
        self.coregistration = coregistration

    def parse_target(self, target: Union[TypeSpacing, str]):
        if isinstance(target, str):
            self.reference_image = target
            self.target_spacing = None
        else:
            self.reference_image = None
            self.target_spacing = self.parse_spacing(target)

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        if isinstance(spacing, tuple) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(spacing) <= 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    @staticmethod
    def parse_interpolation(interpolation: Interpolation) -> int:
        if interpolation == Interpolation.NEAREST:
            order = 0
        elif interpolation == Interpolation.LINEAR:
            order = 1
        elif interpolation == Interpolation.BSPLINE:
            order = 3
        else:
            message = f'Interpolation not implemented yet: {interpolation}'
            raise NotImplementedError(message)
        return order

    @staticmethod
    def check_reference_image(reference_image: str, sample: Subject):
        if not isinstance(reference_image, str):
            message = (
                'reference_image argument should be of type str,'
                f' type {type(reference_image)} was given'
            )
            raise TypeError(message)
        if reference_image not in sample.keys():
            message = (
                f'reference_image=\'{reference_image}\' not present in sample,'
                f' only these keys were found: {sample.keys()}'
            )
            raise ValueError(message)

    @staticmethod
    def check_coregistration(coregistration: str, image_dict: dict):
        if not isinstance(coregistration, str):
            message = (
                'coregistration argument should be of type str,'
                f' type {type(coregistration)} was given'
            )
            raise TypeError(message)
        if coregistration in image_dict.keys():
            matrix = image_dict[coregistration]
            if not isinstance(matrix, np.ndarray):
                message = (
                    f'coregistration matrix={matrix} should be of type'
                    f' np.ndarray, type {type(matrix)} was found'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    f'coregistration matrix={matrix} should be of shape'
                    f' (4, 4), shape {matrix.shape} was found'
                )
                raise ValueError(message)

    @staticmethod
    def check_coregistration_key_presence(coregistration: str, images: Iterable[dict]):
        for image_dict in images:
            if coregistration in image_dict:
                return
        message = (
            f'coregistration key "{coregistration}" should be present'
            ' in at least one of the sample\'s images'
        )
        raise ValueError(message)

    def apply_transform(self, sample: Subject) -> Subject:
        use_reference = self.reference_image is not None
        use_coregistration = self.coregistration is not None
        sample_images = sample.get_images_dict(intensity_only=False)
        if use_coregistration:
            self.check_coregistration_key_presence(
                self.coregistration, sample_images.values())
        for image_name, image_dict in sample_images.items():
            # Do not resample the reference image if there is one
            if use_reference and image_name == self.reference_image:
                continue

            # Choose interpolator
            if image_dict[TYPE] == LABEL:
                interpolation_order = 0  # nearest neighbor
            else:
                interpolation_order = self.interpolation_order

            # Use the image's coregistration matrix if it has one;
            # the coregistration key does not have to be present in every image
            coregistration_matrix = np.eye(4)
            if use_coregistration:
                self.check_coregistration(self.coregistration, image_dict)
                if self.coregistration in image_dict.keys():
                    coregistration_matrix = image_dict[self.coregistration]

            # Resample
            # Issue: The variables AFFINE and DATA do not seem to be defined
            # (both are imported from ....torchio at the top of the file)
            affine = np.dot(coregistration_matrix, image_dict[AFFINE])
            args = image_dict[DATA], affine, interpolation_order
            if use_reference:
                try:
                    ref_image_dict = sample[self.reference_image]
                except KeyError as error:
                    message = (
                        f'Reference name "{self.reference_image}"'
                        ' not found in sample'
                    )
                    raise ValueError(message) from error
                reference = ref_image_dict[DATA], ref_image_dict[AFFINE]
                kwargs = dict(reference=reference)
            else:
                kwargs = dict(target_spacing=self.target_spacing)
            image_dict[DATA], image_dict[AFFINE] = self.apply_resample(
                *args,
                **kwargs,
            )
        return sample

    @staticmethod
    def apply_resample(
            tensor: torch.Tensor,
            affine: np.ndarray,
            interpolation_order: int,
            target_spacing: Optional[Tuple[float, float, float]] = None,
            reference: Optional[Tuple[torch.Tensor, np.ndarray]] = None,
            ) -> Tuple[torch.Tensor, np.ndarray]:
        array = tensor.numpy()[0]
        if reference is None:
            nii = resample_to_output(
                nib.Nifti1Image(array, affine),
                voxel_sizes=target_spacing,
                order=interpolation_order,
            )
        else:
            reference_tensor, reference_affine = reference
            reference_array = reference_tensor.numpy()[0]
            nii = resample_from_to(
                nib.Nifti1Image(array, affine),
                nib.Nifti1Image(reference_array, reference_affine),
                order=interpolation_order,
            )
        tensor = torch.from_numpy(nii.get_fdata(dtype=np.float32))
        tensor = tensor.unsqueeze(dim=0)
        return tensor, nii.affine
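When coregistration is set, apply_transform() hands the resampler np.dot(coregistration_matrix, image_dict[AFFINE]) instead of the image's own affine. The NumPy sketch below, independent of torchio, illustrates what that left multiplication does for a rigid matrix: the image is rotated and moved in world space, while the voxel spacing encoded in the affine (the norms of its first three columns) stays the same. voxel_sizes() and apply_coregistration() are hypothetical helpers, and the matrix values are made up for illustration.

```python
import numpy as np


def voxel_sizes(affine: np.ndarray) -> np.ndarray:
    """Voxel spacing encoded in a 4x4 affine: norm of each of the first three columns."""
    return np.sqrt((affine[:3, :3] ** 2).sum(axis=0))


def apply_coregistration(coregistration: np.ndarray, affine: np.ndarray) -> np.ndarray:
    """Left-multiply the image affine, mirroring np.dot(coregistration_matrix, affine)."""
    if coregistration.shape != (4, 4):
        raise ValueError(f'Expected shape (4, 4), got {coregistration.shape}')
    return np.dot(coregistration, affine)


# Image with 2 x 3 x 4 mm voxels and its origin at (0, 0, 0)
affine = np.diag([2.0, 3.0, 4.0, 1.0])

# Hypothetical rigid coregistration: 90-degree rotation about z, then a 10 mm shift in x
coregistration = np.array([
    [0.0, -1.0, 0.0, 10.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])

new_affine = apply_coregistration(coregistration, affine)
print(voxel_sizes(new_affine))  # spacing is unchanged by the rigid matrix
print(new_affine[:3, 3])        # the origin has moved 10 mm along x
```

Because a rigid coregistration leaves the voxel sizes intact, the spacing-based branch of the transform still produces the requested target spacing; only the position and orientation of the output grid in world space change.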