Passed
Pull Request — master (#346)
by Fernando
01:46
created

Transform.parse_params()   B

Complexity

Conditions 6

Size

Total Lines 17
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 14
nop 5
dl 0
loc 17
rs 8.6666
c 0
b 0
f 0
1
import copy
2
import numbers
3
import warnings
4
from abc import ABC, abstractmethod
5
from typing import Optional, Union, Tuple, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from .. import TypeData, DATA, AFFINE, TypeNumber
13
from ..data.subject import Subject
14
from ..data.image import Image, ScalarImage
15
from ..utils import nib_to_sitk, sitk_to_nib, to_tuple
16
from .interpolation import Interpolation
17
18
19
# Every input type accepted (and returned, with the same type) by
# Transform.__call__
TypeTransformInput = Union[
    Subject,
    Image,
    torch.Tensor,
    np.ndarray,
    sitk.Image,
    nib.Nifti1Image,
    dict,
]


class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All subclasses must override
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes data, applies some transformation and returns the result.

    The input can be an instance of
    :py:class:`torchio.Subject`,
    :py:class:`torchio.Image`,
    :py:class:`numpy.ndarray`,
    :py:class:`torch.Tensor`,
    :py:class:`SimpleITK.Image`,
    :py:class:`nibabel.Nifti1Image`,
    or a Python dictionary. The output has the same type as the input.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        keys: Mandatory if the input is a Python dictionary. The transform will
            be applied only to the data in each key.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[List[str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys
        # Key under which a lone image, tensor, array or SimpleITK/NIfTI
        # image is stored in the temporary Subject built by __call__
        self.default_image_name = 'default_image_name'

    def __call__(self, data: TypeTransformInput) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :py:class:`~torchio.Subject`,
                :py:class:`~torchio.Image`, :py:class:`SimpleITK.Image`,
                :py:class:`nibabel.Nifti1Image`, a Python dictionary (see
                ``keys``), or a 4D :py:class:`torch.Tensor` or NumPy array
                with dimensions :math:`(C, W, H, D)`, where :math:`C` is the
                number of channels and :math:`W, H, D` are the spatial
                dimensions. If the input is a tensor or an array, the affine
                matrix is an identity and an object of the same type will be
                returned.

        Raises:
            RuntimeError: If ``data`` is a dictionary and ``keys`` was not
                set, or if a multichannel result must be returned as a
                NIfTI image.
            ValueError: If the type of ``data`` is not supported, or if a
                tensor or array input is not 4D.
        """
        # torch.rand samples from [0, 1); using ">=" makes the probability of
        # applying the transform exactly p (p=0 never applies, p=1 always does)
        if torch.rand(1).item() >= self.probability:
            return data

        is_tensor = is_array = is_dict = is_image = is_sitk = is_nib = False

        # Wrap every supported input type into a Subject.
        # NOTE(review): Image and Subject are checked before the generic dict
        # branch; presumably they are dict subclasses, so keep this order.
        if isinstance(data, nib.Nifti1Image):
            tensor = data.get_fdata(dtype=np.float32)
            # A 3D volume has no channel dimension; add one so the image is 4D
            # like every other input (channel 0 is taken back out on output)
            if tensor.ndim == 3:
                tensor = tensor[np.newaxis]
            data = ScalarImage(tensor=tensor, affine=data.affine)
            subject = self._get_subject_from_image(data)
            is_nib = True
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            subject = self.parse_tensor(data)
            is_array = isinstance(data, np.ndarray)
            is_tensor = True
        elif isinstance(data, Image):
            subject = self._get_subject_from_image(data)
            is_image = True
        elif isinstance(data, Subject):
            subject = data
        elif isinstance(data, sitk.Image):
            subject = self._get_subject_from_sitk_image(data)
            is_sitk = True
        elif isinstance(data, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                message = (
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
                raise RuntimeError(message)
            subject = self._get_subject_from_dict(data, self.keys)
            is_dict = True
        else:
            raise ValueError(f'Input type not recognized: {type(data)}')
        self.parse_subject(subject)

        if self.copy:
            subject = copy.copy(subject)

        # Make NumPy floating-point errors raise instead of only warning, so
        # numerical problems inside the transform are not silently ignored
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)

        for image in transformed.get_images(intensity_only=False):
            ndim = image[DATA].ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        # Convert the result back to the type of the input
        if is_tensor or is_sitk:
            image = transformed[self.default_image_name]
            transformed = image[DATA]
            if is_array:
                transformed = transformed.numpy()
            elif is_sitk:
                transformed = nib_to_sitk(image[DATA], image[AFFINE])
        elif is_image:
            transformed = transformed[self.default_image_name]
        elif is_dict:
            transformed = dict(transformed)
            for key, value in transformed.items():
                if isinstance(value, Image):
                    transformed[key] = value.data
        elif is_nib:
            image = transformed[self.default_image_name]
            data = image[DATA]
            if len(data) > 1:
                message = (
                    'Multichannel images not supported for input of type'
                    ' nibabel.nifti.Nifti1Image'
                )
                raise RuntimeError(message)
            transformed = nib.Nifti1Image(data[0].numpy(), image[AFFINE])
        return transformed

    @abstractmethod
    def apply_transform(self, subject: Subject):
        """Apply the transform to ``subject`` and return the result."""
        raise NotImplementedError

    @staticmethod
    def to_range(n, around):
        """Return ``(0, n)`` if ``around`` is ``None``, else
        ``(around - n, around + n)``."""
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, **kwargs):
        """Expand ``params`` into six numbers forming three ranges.

        Args:
            params: A number :math:`d`, a pair :math:`(a, b)`, a triplet
                :math:`(a, b, c)` or a sequence of six numbers.
            around: Passed to :py:meth:`to_range` to build a range from each
                value when ``params`` is a number or a triplet.
            name: Name of the parameter, used in error messages.
            **kwargs: Constraints forwarded to :py:meth:`parse_range`.

        Returns:
            A tuple of six numbers :math:`(a_1, b_1, a_2, b_2, a_3, b_3)`,
            where each pair :math:`(a_i, b_i)` has been validated by
            :py:meth:`parse_range`.

        Raises:
            ValueError: If ``params`` cannot be expanded to six numbers or a
                range fails validation.
        """
        params = to_tuple(params)
        # d -> (d, d, d);  (a, b) -> (a, b, a, b, a, b)
        if len(params) in (1, 2):
            params *= 3
        if len(params) == 3:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (0, a, 0, b, 0, c) if around is None, otherwise
            # (t - a, t + a, t - b, t + b, t - c, t + c) with t = around
            params = [n for prange in items for n in prange]
        if len(params) != 6:
            message = (
                f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                f' not {len(params)}'
            )
            raise ValueError(message)
        for param_range in zip(params[::2], params[1::2]):
            self.parse_range(param_range, name, **kwargs)
        return tuple(params)

    @staticmethod
    def parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: Optional[TypeNumber] = None,
            max_constraint: Optional[TypeNumber] = None,
            type_constraint: Optional[type] = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single non-negative number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n` when
                :attr:`min_constraint` is None; otherwise
                :math:`n_{min} = n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`. If a sequence
            was given, it is returned unchanged.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` or :math:`n_{max}` is not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError):
            # The message below fully describes the problem; hide the
            # low-level unpacking error
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            ) from None

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                # Report the types of the values, not of the container
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not'
                    f' "{type(min_value)}" and "{type(max_value)}"'
                )
        return nums_range

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Return ``probability`` if it is a number in :math:`[0, 1]`.

        Raises:
            ValueError: If ``probability`` is not a number in that interval.
        """
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_subject(subject: Subject) -> None:
        """Raise a ``RuntimeError`` if ``subject`` is not a Subject."""
        if not isinstance(subject, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(subject)}"'
            )
            raise RuntimeError(message)

    def parse_tensor(self, data: TypeData) -> Subject:
        """Wrap a 4D tensor or array into a single-image Subject.

        Raises:
            ValueError: If ``data`` is not 4D.
        """
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    @staticmethod
    def parse_interpolation(
            interpolation: Union[str, Interpolation],
            ) -> Interpolation:
        """Convert an interpolation name into an :py:class:`Interpolation`.

        Strings are matched case-insensitively against the member names.
        Passing an :py:class:`Interpolation` member is deprecated: it emits
        a ``FutureWarning`` and is returned unchanged.

        Raises:
            AttributeError: If the string is not a supported member name.
            TypeError: If ``interpolation`` is neither a string nor an
                :py:class:`Interpolation`.
        """
        if isinstance(interpolation, Interpolation):
            message = (
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead'
            )
            warnings.warn(message, FutureWarning)
        elif isinstance(interpolation, str):
            interpolation = interpolation.lower()
            supported_values = [key.name.lower() for key in Interpolation]
            if interpolation in supported_values:
                interpolation = getattr(Interpolation, interpolation.upper())
            else:
                message = (
                    f'Interpolation "{interpolation}" is not among'
                    f' the supported values: {supported_values}'
                )
                raise AttributeError(message)
        else:
            message = (
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )
            raise TypeError(message)
        return interpolation

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        """Wrap ``tensor`` in a ScalarImage (default affine) and a Subject."""
        image = ScalarImage(tensor=tensor)
        return self._get_subject_from_image(image)

    def _get_subject_from_image(self, image: Image) -> Subject:
        """Return a Subject whose only entry is ``image``."""
        subject = Subject({self.default_image_name: image})
        return subject

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        """Build a Subject from ``data``.

        Values whose key is in ``image_keys`` are wrapped in a
        :py:class:`ScalarImage`; all other values are kept as they are.
        """
        subject_dict = {}
        for key, value in data.items():
            if key in image_keys:
                value = ScalarImage(tensor=value)
            subject_dict[key] = value
        return Subject(subject_dict)

    def _get_subject_from_sitk_image(self, image):
        """Convert a SimpleITK image into a single-image Subject."""
        tensor, affine = sitk_to_nib(image)
        image = ScalarImage(tensor=tensor, affine=affine)
        return self._get_subject_from_image(image)

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Delegate to the module-level ``nib_to_sitk`` helper."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Delegate to the module-level ``sitk_to_nib`` helper."""
        return sitk_to_nib(image)

    @property
    def name(self):
        """Name of the concrete transform class."""
        return self.__class__.__name__