Passed
Pull Request — master (#121)
by Fernando
01:28
created

Resample.check_reference_image()   A

Complexity

Conditions 3

Size

Total Lines 8
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 2
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
from numbers import Number
2
from typing import Union, Tuple, Optional
3
import torch
0 ignored issues
show
introduced by
Unable to import 'torch'
Loading history...
4
import numpy as np
0 ignored issues
show
introduced by
Unable to import 'numpy'
Loading history...
5
import nibabel as nib
0 ignored issues
show
introduced by
Unable to import 'nibabel'
Loading history...
6
from nibabel.processing import resample_to_output, resample_from_to
0 ignored issues
show
introduced by
Unable to import 'nibabel.processing'
Loading history...
7
from ....utils import is_image_dict
8
from ....torchio import LABEL, DATA, AFFINE, TYPE
9
from ... import Interpolation
10
from ... import Transform
11
12
13
# Voxel spacing: either one value used for all three axes, or one value per axis.
TypeSpacing = Union[float, Tuple[float, float, float]]
14
15
16
class Resample(Transform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_d, s_h, s_w)`. If only one value
            :math:`n` is specified, then :math:`s_d = s_h = s_w = n`.
            A list of three values is also accepted.
            If a string is given, all images will be resampled using the image
            with that name as reference.
        antialiasing: (Not implemented yet).
        image_interpolation: Member of :py:class:`torchio.Interpolation`.
            Supported interpolation techniques for resampling are
            :py:attr:`torchio.Interpolation.NEAREST`,
            :py:attr:`torchio.Interpolation.LINEAR` and
            :py:attr:`torchio.Interpolation.BSPLINE`.
            Images whose type is ``LABEL`` are always resampled with
            nearest-neighbor interpolation, regardless of this argument.

    .. note:: Resampling is performed using
        :py:meth:`nibabel.processing.resample_to_output` or
        :py:meth:`nibabel.processing.resample_from_to`, depending on whether
        the target is a spacing or a reference image.

    Example:
        >>> from torchio.transforms import Resample
        >>> transform = Resample(1)          # resample all images to 1mm iso
        >>> transform = Resample((1, 1, 1))  # resample all images to 1mm iso
        >>> transform = Resample('t1')       # resample all images to 't1' image space

    """
    def __init__(
            self,
            target: Union[TypeSpacing, str],
            antialiasing: bool = True,
            image_interpolation: Interpolation = Interpolation.LINEAR,
            ):
        super().__init__()
        # parse_target sets exactly one of these; the other is left as None.
        self.target_spacing: Optional[Tuple[float, float, float]]
        self.reference_image: Optional[str]
        self.parse_target(target)
        # Stored only; not read by the resampling code below yet.
        self.antialiasing = antialiasing
        self.interpolation_order = self.parse_interpolation(
            image_interpolation)

    def parse_target(self, target: Union[TypeSpacing, str]):
        """Set either ``reference_image`` or ``target_spacing`` from ``target``.

        A string is taken as the name of a reference image in the sample;
        anything else is validated as a spacing by :meth:`parse_spacing`.
        """
        if isinstance(target, str):
            self.reference_image = target
            self.target_spacing = None
        else:
            self.reference_image = None
            self.target_spacing = self.parse_spacing(target)

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Return ``spacing`` as a 3-tuple after checking it is positive.

        Args:
            spacing: A single number, used for all three axes, or a tuple
                or list of three numbers.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a 3-element
                tuple/list, or if any of its values is not positive.
        """
        if isinstance(spacing, (tuple, list)) and len(spacing) == 3:
            result = tuple(spacing)
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    @staticmethod
    def parse_interpolation(interpolation: Interpolation) -> int:
        """Map an :class:`Interpolation` member to a spline order.

        NEAREST -> 0, LINEAR -> 1, BSPLINE -> 3.

        Raises:
            NotImplementedError: For any other interpolation value.
        """
        if interpolation == Interpolation.NEAREST:
            order = 0
        elif interpolation == Interpolation.LINEAR:
            order = 1
        elif interpolation == Interpolation.BSPLINE:
            order = 3
        else:
            message = f'Interpolation not implemented yet: {interpolation}'
            raise NotImplementedError(message)
        return order

    @staticmethod
    def check_reference_image(reference_image: str, sample: dict):
        """Validate that ``reference_image`` is a string key of ``sample``.

        Raises:
            TypeError: If ``reference_image`` is not a string.
            ValueError: If ``reference_image`` is not a key of ``sample``.
        """
        if not isinstance(reference_image, str):
            message = (
                'reference_image argument should be of type str,'
                f' type {type(reference_image)} was given'
            )
            raise TypeError(message)
        if reference_image not in sample:
            message = (
                f'reference_image=\'{reference_image}\' not present in sample,'
                f' only these keys were found: {sample.keys()}'
            )
            raise ValueError(message)

    def apply_transform(self, sample: dict) -> dict:
        """Resample the images of ``sample`` in place and return ``sample``.

        If a reference image name was given, that image is left untouched
        and every other image is resampled onto its grid. Otherwise, every
        image is resampled to ``self.target_spacing``. Entries rejected by
        ``is_image_dict`` are skipped.

        Raises:
            ValueError: If the reference image name is not a key of ``sample``.
        """
        use_reference = self.reference_image is not None
        if use_reference:
            # Resolve the reference once, up front. The loop below skips the
            # reference entry, so its data/affine cannot change while we iterate.
            try:
                ref_image_dict = sample[self.reference_image]
            except KeyError as error:
                message = (
                    f'Reference name "{self.reference_image}"'
                    ' not found in sample'
                )
                raise ValueError(message) from error
            reference = ref_image_dict[DATA], ref_image_dict[AFFINE]
            kwargs = dict(reference=reference)
        else:
            kwargs = dict(target_spacing=self.target_spacing)

        for image_name, image_dict in sample.items():
            # Do not resample the reference image if there is one
            if use_reference and image_name == self.reference_image:
                continue
            if not is_image_dict(image_dict):
                continue

            # Label maps hold discrete values, so always use nearest neighbor
            if image_dict[TYPE] == LABEL:
                interpolation_order = 0
            else:
                interpolation_order = self.interpolation_order

            image_dict[DATA], image_dict[AFFINE] = self.apply_resample(
                image_dict[DATA],
                image_dict[AFFINE],
                interpolation_order,
                **kwargs,
            )
        return sample

    @staticmethod
    def apply_resample(
            tensor: torch.Tensor,
            affine: np.ndarray,
            interpolation_order: int,
            target_spacing: Optional[Tuple[float, float, float]] = None,
            reference: Optional[Tuple[torch.Tensor, np.ndarray]] = None,
            ) -> Tuple[torch.Tensor, np.ndarray]:
        """Resample one image and return the new tensor and affine.

        If ``reference`` is ``None``, the image is resampled to voxel sizes
        ``target_spacing`` with ``resample_to_output``. Otherwise it is
        resampled onto the grid of the reference ``(tensor, affine)`` pair
        with ``resample_from_to``.

        Args:
            tensor: Image data; only index 0 of its first dimension is used.
            affine: Voxel-to-world affine of ``tensor``.
            interpolation_order: Spline order passed to nibabel.
            target_spacing: Output voxel sizes; used when ``reference`` is None.
            reference: ``(tensor, affine)`` of the image defining the output grid.

        Returns:
            A ``float32`` tensor with a leading singleton dimension, and the
            affine of the resampled image.
        """
        # NOTE(review): only index 0 of the leading dimension is resampled;
        # presumably tensors are (1, D, H, W). Confirm for multi-channel data.
        array = tensor.numpy()[0]
        moving_image = nib.Nifti1Image(array, affine)
        if reference is None:
            nii = resample_to_output(
                moving_image,
                voxel_sizes=target_spacing,
                order=interpolation_order,
            )
        else:
            reference_tensor, reference_affine = reference
            reference_array = reference_tensor.numpy()[0]
            nii = resample_from_to(
                moving_image,
                nib.Nifti1Image(reference_array, reference_affine),
                order=interpolation_order,
            )
        tensor = torch.from_numpy(nii.get_fdata(dtype=np.float32))
        tensor = tensor.unsqueeze(dim=0)
        return tensor, nii.affine
166