Passed
Pull Request — master (#353)
by Fernando
01:16
created

Transform.parse_subject()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
import copy
2
import numbers
3
import warnings
4
from abc import ABC, abstractmethod
5
from typing import Optional, Union, Tuple, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from .. import TypeData, DATA, AFFINE, TypeNumber
13
from ..data.subject import Subject
14
from ..data.image import Image, ScalarImage
15
from ..utils import nib_to_sitk, sitk_to_nib, to_tuple
16
from .interpolation import Interpolation
17
18
19
# Input (and output) types handled by ``Transform.__call__``.
# ``nib.Nifti1Image`` is listed because ``__call__`` has an explicit branch
# for it, so the alias should advertise it like the other supported types.
TypeTransformInput = Union[
    Subject,
    Image,
    torch.Tensor,
    np.ndarray,
    sitk.Image,
    dict,
    nib.Nifti1Image,
]
27
28
29
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All subclasses should overwrite
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes data, applies some transformation and returns the result.

    The input can be an instance of
    :py:class:`torchio.Subject`,
    :py:class:`torchio.Image`,
    :py:class:`numpy.ndarray`,
    :py:class:`torch.Tensor`,
    :py:class:`SimpleITK.Image`,
    :py:class:`nibabel.Nifti1Image`,
    or a Python dictionary.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        keys: Mandatory if the input is a Python dictionary. The transform will
            be applied only to the data in each key.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[List[str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys
        # Key under which a bare image/tensor is stored when it is wrapped
        # into a temporary Subject
        self.default_image_name = 'default_image_name'

    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :py:class:`~torchio.Subject`, 4D
                :py:class:`torch.Tensor` or 4D NumPy array with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix is an identity and a tensor will be
                also returned.

        Returns:
            The input itself, untouched, if the transform is randomly skipped
            according to the probability; otherwise the transformed data,
            converted back to the type of the input.

        Raises:
            RuntimeError: if the input is a dictionary and no ``keys`` were
                given at instantiation, or if a multichannel result would have
                to be returned as a NIfTI image.
            ValueError: if the input type is not recognized.
        """
        if torch.rand(1).item() > self.probability:
            return data

        # Flags remembering the input type, so that the result can be
        # converted back after the transform has been applied to a Subject
        is_tensor = is_array = is_dict = is_image = is_sitk = is_nib = False

        if isinstance(data, nib.Nifti1Image):
            tensor = data.get_fdata(dtype=np.float32)
            data = ScalarImage(tensor=tensor, affine=data.affine)
            subject = self._get_subject_from_image(data)
            is_nib = True
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            subject = self.parse_tensor(data)
            # NumPy arrays follow the tensor path and are converted at the end
            is_array = isinstance(data, np.ndarray)
            is_tensor = True
        elif isinstance(data, Image):
            subject = self._get_subject_from_image(data)
            is_image = True
        elif isinstance(data, Subject):
            subject = data
        elif isinstance(data, sitk.Image):
            subject = self._get_subject_from_sitk_image(data)
            is_sitk = True
        elif isinstance(data, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                message = (
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
                raise RuntimeError(message)
            subject = self._get_subject_from_dict(data, self.keys)
            is_dict = True
        else:
            raise ValueError(f'Input type not recognized: {type(data)}')
        self.parse_subject(subject)

        if self.copy:
            subject = copy.copy(subject)

        # Turn NumPy floating-point warnings into exceptions while the
        # transform runs
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)
            # Imported here rather than at module level
            # (presumably to avoid a circular import; verify)
            from .augmentation import RandomTransform
            if not isinstance(self, RandomTransform):
                transformed.add_transform(self, self.get_arguments())

        for image in transformed.get_images(intensity_only=False):
            ndim = image[DATA].ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        # Convert the transformed Subject back to the type of the input
        if is_tensor or is_sitk:
            image = transformed[self.default_image_name]
            transformed = image[DATA]
            if is_array:
                transformed = transformed.numpy()
            elif is_sitk:
                transformed = nib_to_sitk(image[DATA], image[AFFINE])
        elif is_image:
            transformed = transformed[self.default_image_name]
        elif is_dict:
            transformed = dict(transformed)
            # Only existing keys are reassigned, so the dictionary size does
            # not change during iteration
            for key, value in transformed.items():
                if isinstance(value, Image):
                    transformed[key] = value.data
        elif is_nib:
            image = transformed[self.default_image_name]
            data = image[DATA]
            # The first dimension is the channels dimension
            if len(data) > 1:
                message = (
                    'Multichannel images not supported for input of type'
                    ' nibabel.nifti.Nifti1Image'
                )
                raise RuntimeError(message)
            transformed = nib.Nifti1Image(data[0].numpy(), image[AFFINE])

        return transformed

    @abstractmethod
    def apply_transform(self, subject: Subject):
        """Apply the transform to a subject. Must be overwritten."""
        raise NotImplementedError

    @staticmethod
    def to_range(n, around):
        """Build a range of half-width ``n``.

        Args:
            n: Number defining the extent of the range.
            around: Center of the range. If ``None``, the range starts at 0.

        Returns:
            ``(0, n)`` if ``around`` is ``None``, otherwise
            ``(around - n, around + n)``.
        """
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Expand a parameter into one value or one range per dimension.

        Args:
            params: Value or sequence of values, converted with ``to_tuple``.
            around: Center passed to :py:meth:`to_range` when three values
                are expanded into three ranges.
            name: Name of the parameter, used in error messages.
            make_ranges: If ``True``, the result is six numbers describing
                three ``(min, max)`` ranges, each validated with
                :py:meth:`parse_range`. If ``False``, only a single value is
                repeated three times and no validation is performed.
            **kwargs: Constraints forwarded to :py:meth:`parse_range`.

        Returns:
            Tuple with the expanded parameters.

        Raises:
            ValueError: if ``make_ranges`` is ``True`` and the parameters
                cannot be expanded to six values, or if any of the resulting
                ranges is rejected by :py:meth:`parse_range`.
        """
        params = to_tuple(params)
        # d or (a, b)
        if len(params) == 1 or (len(params) == 2 and make_ranges):
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3 and make_ranges:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        # The guard must depend on make_ranges only: if it also required
        # len(params) != 6, the validation loop below could never be reached
        if make_ranges:
            if len(params) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                    f' not {len(params)}'
                )
                raise ValueError(message)
            for param_range in zip(params[::2], params[1::2]):
                self.parse_range(param_range, name, **kwargs)
        return tuple(params)

    @staticmethod
    def parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`, unless
                :attr:`min_constraint` is given, in which case
                :math:`n_{min} = n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            # With a lower bound, the range collapses to the number itself
            # instead of being mirrored around zero
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not "{type(nums_range)}"'
                )
        return nums_range

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Check that the probability is a number in [0, 1] and return it.

        Raises:
            ValueError: if the value is not a number or is out of range.
        """
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_subject(subject: Subject) -> None:
        """Check that the argument is an instance of ``Subject``.

        Raises:
            RuntimeError: if it is not.
        """
        if not isinstance(subject, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(subject)}"'
            )
            raise RuntimeError(message)

    def parse_tensor(self, data: TypeData) -> Subject:
        """Wrap a 4D tensor or array into a subject.

        Raises:
            ValueError: if the input is not 4D.
        """
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    @staticmethod
    def parse_interpolation(interpolation: str) -> Interpolation:
        """Convert the name of an interpolation into an ``Interpolation``.

        Args:
            interpolation: Case-insensitive name of a member of
                ``Interpolation``. Instances of ``Interpolation`` are still
                accepted and returned as they are, but a ``FutureWarning``
                is issued.

        Raises:
            AttributeError: if the string is not the name of a member.
            TypeError: if the input is neither a string nor an
                ``Interpolation``.
        """
        if isinstance(interpolation, Interpolation):
            message = (
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead'
            )
            warnings.warn(message, FutureWarning)
        elif isinstance(interpolation, str):
            interpolation = interpolation.lower()
            supported_values = [key.name.lower() for key in Interpolation]
            if interpolation in supported_values:
                interpolation = getattr(Interpolation, interpolation.upper())
            else:
                message = (
                    f'Interpolation "{interpolation}" is not among'
                    f' the supported values: {supported_values}'
                )
                raise AttributeError(message)
        else:
            message = (
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )
            raise TypeError(message)
        return interpolation

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        """Wrap a tensor into a ``ScalarImage`` inside a new subject."""
        image = ScalarImage(tensor=tensor)
        return self._get_subject_from_image(image)

    def _get_subject_from_image(self, image: Image) -> Subject:
        """Create a subject whose only entry is the image."""
        subject = Subject({self.default_image_name: image})
        return subject

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        """Create a subject from a dictionary.

        Values whose key is in ``image_keys`` are wrapped into a
        ``ScalarImage``; all other values are kept as they are.
        """
        subject_dict = {}
        for key, value in data.items():
            if key in image_keys:
                value = ScalarImage(tensor=value)
            subject_dict[key] = value
        return Subject(subject_dict)

    def _get_subject_from_sitk_image(self, image):
        """Create a subject from a SimpleITK image, keeping its affine."""
        tensor, affine = sitk_to_nib(image)
        image = ScalarImage(tensor=tensor, affine=affine)
        return self._get_subject_from_image(image)

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Thin wrapper around the module-level ``nib_to_sitk`` helper."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Thin wrapper around the module-level ``sitk_to_nib`` helper."""
        return sitk_to_nib(image)

    @property
    def name(self):
        """Name of the class of this transform."""
        return self.__class__.__name__

    def is_invertible(self):
        """Return whether the instance has an ``invert_transform`` attribute."""
        return hasattr(self, 'invert_transform')

    def inverse(self):
        """Return a deep copy of the transform with the inversion flag flipped.

        Raises:
            RuntimeError: if the transform is not invertible.
        """
        if not self.is_invertible():
            raise RuntimeError(f'{self.name} is not invertible')
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new
387