Passed
Push — master ( c291a8...879ee9 )
by Fernando
59s
created

Transform.parse_sample()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import numbers
2
import warnings
3
from typing import Union, Tuple
4
from copy import deepcopy
5
from abc import ABC, abstractmethod
6
7
import torch
8
import numpy as np
9
import SimpleITK as sitk
10
11
from .. import TypeData, INTENSITY, DATA, TypeNumber
12
from ..data.image import Image
13
from ..data.subject import Subject
14
from ..data.dataset import ImagesDataset
15
from ..utils import nib_to_sitk, sitk_to_nib
16
from .interpolation import Interpolation
17
18
19
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All classes used to transform a sample from an
    :py:class:`~torchio.ImagesDataset` should subclass it.
    All subclasses should overwrite
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes a sample, applies some transformation and returns the result.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a deep copy of the input before applying the transform.
    """
    def __init__(self, p: float = 1, copy: bool = True):
        # Validate eagerly so a bad probability fails at construction time.
        self.probability = self.parse_probability(p)
        self.copy = copy

    def __call__(self, data: Union['Subject', torch.Tensor, np.ndarray]):
        """Transform a sample and return the result.

        Args:
            data: Instance of :py:class:`~torchio.Subject`, 4D
                :py:class:`torch.Tensor` or 4D NumPy array with dimensions
                :math:`(C, D, H, W)`, where :math:`C` is the number of channels
                and :math:`D, H, W` are the spatial dimensions. If the input is
                a tensor, the affine matrix is an identity and a tensor will be
                also returned.

        Returns:
            The input itself (untouched) when the random draw exceeds the
            transform probability. Otherwise the transformed sample, converted
            back to a tensor or a NumPy array if that is what was passed in.
        """
        # Random draw in [0, 1): skip the transform with probability 1 - p.
        if torch.rand(1).item() > self.probability:
            return data

        is_array = isinstance(data, np.ndarray)
        # NumPy arrays go through the tensor code path too; they are only
        # converted back to NumPy at the very end.
        is_tensor = is_array or isinstance(data, torch.Tensor)
        sample = self.parse_tensor(data) if is_tensor else data
        self.parse_sample(sample)

        if self.copy:
            sample = deepcopy(sample)

        # Turn NumPy floating-point warnings into exceptions while the
        # transform runs.
        with np.errstate(all='raise'):
            transformed = self.apply_transform(sample)

        if is_tensor:
            # Reassemble the channels in the order used by
            # _get_subject_from_tensor.
            num_channels = len(data)
            images = [
                transformed[f'channel_{i}'][DATA]
                for i in range(num_channels)
            ]
            transformed = torch.cat(images)
        if is_array:
            transformed = transformed.numpy()
        return transformed

    @abstractmethod
    def apply_transform(self, sample: 'Subject'):
        """Apply the transform to ``sample``; subclasses must implement it."""
        raise NotImplementedError

    @staticmethod
    def parse_range(
            nums_range: Union['TypeNumber', Tuple['TypeNumber', 'TypeNumber']],
            name: str,
            min_constraint: 'TypeNumber' = None,
            max_constraint: 'TypeNumber' = None,
            type_constraint: type = None,
            ) -> Tuple['TypeNumber', 'TypeNumber']:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`, unless
                :attr:`min_constraint` is given, in which case
                :math:`n_{min} = n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`. When a sequence
            is passed in and is valid, it is returned as is.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None and \
                    not isinstance(nums_range, type_constraint):
                raise ValueError(
                    f'If {name} is a single number, it must be of'
                    f' type {type_constraint}, not {nums_range}'
                )
            # A lower bound rules out mirroring around zero, so the range
            # collapses to the single value in that case.
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_degree, max_degree = nums_range
        except (TypeError, ValueError) as error:
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            ) from error

        if not isinstance(min_degree, numbers.Number) or \
                not isinstance(max_degree, numbers.Number):
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_degree > max_degree:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')

        if min_constraint is not None and min_degree < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, not {min_degree}'
            )

        if max_constraint is not None and max_degree > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, not {max_degree}'
            )

        if type_constraint is not None:
            if not isinstance(min_degree, type_constraint) or \
                    not isinstance(max_degree, type_constraint):
                raise ValueError(
                    f'If {name} is a sequence, its values must be of'
                    f' type {type_constraint}, not {nums_range}'
                )
        return nums_range

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Check that ``probability`` is a number in [0, 1] and return it.

        Raises:
            ValueError: if the value is not a number or lies outside [0, 1].
        """
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_sample(sample: 'Subject') -> None:
        """Ensure that ``sample`` is a ``Subject`` instance.

        Raises:
            RuntimeError: if ``sample`` is not an instance of ``Subject``.
        """
        if not isinstance(sample, Subject):
            message = (
                'Input to a transform must be a PyTorch tensor or an instance'
                ' of torchio.Subject generated by a torchio.ImagesDataset,'
                f' not "{type(sample)}"'
            )
            raise RuntimeError(message)

    def parse_tensor(self, data: 'TypeData') -> 'Subject':
        """Wrap a 4D tensor or NumPy array into a ``Subject``.

        The data is cast to float and split into one image per channel.

        Raises:
            RuntimeError: if the input does not have exactly 4 dimensions.
        """
        if isinstance(data, np.ndarray):
            tensor = torch.from_numpy(data)
        else:
            tensor = data
        tensor = tensor.float()  # does nothing if already float
        num_dimensions = tensor.dim()
        if num_dimensions != 4:
            message = (
                'The input tensor must have 4 dimensions (channels, i, j, k),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        return self._get_subject_from_tensor(tensor)

    @staticmethod
    def parse_interpolation(interpolation: str) -> 'Interpolation':
        """Convert an interpolation name into an ``Interpolation`` member.

        The lookup is case-insensitive. Passing an ``Interpolation`` instance
        still works: it is returned unchanged after a ``FutureWarning``.

        Raises:
            AttributeError: if the string is not the name of a member of
                ``Interpolation``.
            TypeError: if the argument is neither a string nor an
                ``Interpolation``.
        """
        if isinstance(interpolation, Interpolation):
            message = (
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead'
            )
            warnings.warn(message, FutureWarning)
        elif isinstance(interpolation, str):
            interpolation = interpolation.lower()
            supported_values = [key.name.lower() for key in Interpolation]
            if interpolation in supported_values:
                interpolation = getattr(Interpolation, interpolation.upper())
            else:
                message = (
                    f'Interpolation "{interpolation}" is not among'
                    f' the supported values: {supported_values}'
                )
                raise AttributeError(message)
        else:
            message = (
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )
            raise TypeError(message)
        return interpolation

    @staticmethod
    def _get_subject_from_tensor(tensor: torch.Tensor) -> 'Subject':
        """Build a ``Subject`` holding one intensity image per channel.

        Images are stored under the keys ``channel_0``, ``channel_1``, ...
        following the order of the first dimension of ``tensor``.
        """
        subject_dict = {}
        for channel_index, channel_tensor in enumerate(tensor):
            name = f'channel_{channel_index}'
            image = Image(tensor=channel_tensor, type=INTENSITY)
            subject_dict[name] = image
        subject = Subject(subject_dict)
        return subject

    @staticmethod
    def nib_to_sitk(data: 'TypeData', affine: 'TypeData'):
        """Delegate to the module-level ``nib_to_sitk`` helper."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: 'sitk.Image'):
        """Delegate to the module-level ``sitk_to_nib`` helper."""
        return sitk_to_nib(image)

    @property
    def name(self):
        """Name of the concrete transform class."""
        return self.__class__.__name__
266