Passed
Pull Request — master (#353)
by Fernando
01:11
created

torchio.transforms.transform.Transform.name()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import copy
2
import numbers
3
import warnings
4
from abc import ABC, abstractmethod
5
from typing import Optional, Union, Tuple, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from .. import TypeData, DATA, AFFINE, TypeNumber
13
from ..data.subject import Subject
14
from ..data.image import Image, ScalarImage
15
from ..utils import nib_to_sitk, sitk_to_nib, is_jsonable, to_tuple
16
from .interpolation import Interpolation
17
18
19
# Every input type that ``Transform.__call__`` knows how to wrap into a
# subject and unwrap again. NIfTI images are handled there explicitly, so
# they are listed here as well.
TypeTransformInput = Union[
    Subject,
    Image,
    torch.Tensor,
    np.ndarray,
    sitk.Image,
    nib.Nifti1Image,
    dict,
]
27
28
29
class Transform(ABC):
30
    """Abstract class for all TorchIO transforms.
31
32
    All subclasses should overwrite
33
    :py:meth:`torchio.transforms.Transform.apply_transform`,
34
    which takes data, applies some transformation and returns the result.
35
36
    The input can be an instance of
37
    :py:class:`torchio.Subject`,
38
    :py:class:`torchio.Image`,
39
    :py:class:`numpy.ndarray`,
40
    :py:class:`torch.Tensor`,
41
    :py:class:`SimpleITK.Image`,
42
    or a Python dictionary.
43
44
    Args:
45
        p: Probability that this transform will be applied.
46
        copy: Make a shallow copy of the input before applying the transform.
47
        keys: Mandatory if the input is a Python dictionary. The transform will
48
            be applied only to the data in each key.
49
    """
50
    def __init__(
51
            self,
52
            p: float = 1,
53
            copy: bool = True,
54
            keys: Optional[List[str]] = None,
55
            ):
56
        self.probability = self.parse_probability(p)
57
        self.copy = copy
58
        self.keys = keys
59
        self.default_image_name = 'default_image_name'
60
        self.transform_params = {}
61
62
    def __call__(
63
            self,
64
            data: TypeTransformInput,
65
            ) -> TypeTransformInput:
66
        """Transform data and return a result of the same type.
67
68
        Args:
69
            data: Instance of :py:class:`~torchio.Subject`, 4D
70
                :py:class:`torch.Tensor` or 4D NumPy array with dimensions
71
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
72
                and :math:`W, H, D` are the spatial dimensions. If the input is
73
                a tensor, the affine matrix is an identity and a tensor will be
74
                also returned.
75
        """
76
        self.transform_params = {}
77
        self._store_params()
78
79
        if torch.rand(1).item() > self.probability:
80
            return data
81
82
        is_tensor = is_array = is_dict = is_image = is_sitk = is_nib = False
83
84
        if isinstance(data, nib.Nifti1Image):
85
            tensor = data.get_fdata(dtype=np.float32)
86
            data = ScalarImage(tensor=tensor, affine=data.affine)
87
            subject = self._get_subject_from_image(data)
88
            is_nib = True
89
        elif isinstance(data, (np.ndarray, torch.Tensor)):
90
            subject = self.parse_tensor(data)
91
            is_array = isinstance(data, np.ndarray)
92
            is_tensor = True
93
        elif isinstance(data, Image):
94
            subject = self._get_subject_from_image(data)
95
            is_image = True
96
        elif isinstance(data, Subject):
97
            subject = data
98
        elif isinstance(data, sitk.Image):
99
            subject = self._get_subject_from_sitk_image(data)
100
            is_sitk = True
101
        elif isinstance(data, dict):  # e.g. Eisen or MONAI dicts
102
            if self.keys is None:
103
                message = (
104
                    'If input is a dictionary, a value for "keys" must be'
105
                    ' specified when instantiating the transform'
106
                )
107
                raise RuntimeError(message)
108
            subject = self._get_subject_from_dict(data, self.keys)
109
            is_dict = True
110
        else:
111
            raise ValueError(f'Input type not recognized: {type(data)}')
112
        self.parse_subject(subject)
113
114
        if self.copy:
115
            subject = copy.copy(subject)
116
117
        with np.errstate(all='raise'):
118
            transformed = self.apply_transform(subject)
119
120
        for image in transformed.get_images(intensity_only=False):
121
            ndim = image[DATA].ndim
122
            assert ndim == 4, f'Output of {self.name} is {ndim}D'
123
124
        if is_tensor or is_sitk:
125
            image = transformed[self.default_image_name]
126
            transformed = image[DATA]
127
            if is_array:
128
                transformed = transformed.numpy()
129
            elif is_sitk:
130
                transformed = nib_to_sitk(image[DATA], image[AFFINE])
131
        elif is_image:
132
            transformed = transformed[self.default_image_name]
133
        elif is_dict:
134
            transformed = dict(transformed)
135
            for key, value in transformed.items():
136
                if isinstance(value, Image):
137
                    transformed[key] = value.data
138
        elif is_nib:
139
            image = transformed[self.default_image_name]
140
            data = image[DATA]
141
            if len(data) > 1:
142
                message = (
143
                    'Multichannel images not supported for input of type'
144
                    ' nibabel.nifti.Nifti1Image'
145
                )
146
                raise RuntimeError(message)
147
            transformed = nib.Nifti1Image(data[0].numpy(), image[AFFINE])
148
149
        # # If not a Compose
150
        # if isinstance(transformed, Subject) and not (self.name in ['Compose', 'OneOf']):
151
        #     transformed.add_transform(
152
        #         self,
153
        #         parameters_dict=self.transform_params,
154
        #     )
155
156
        return transformed
157
158
    def _store_params(self):
        """Record the instance attributes in ``self.transform_params``.

        Every instance attribute except ``transform_params`` itself is
        copied into the dictionary. Values for which ``is_jsonable`` returns
        a falsy result are replaced by their string representation.
        """
        params = self.transform_params
        params.update(self.__dict__)
        # The record must not contain itself. ``pop`` with a default does not
        # fail if ``transform_params`` is not stored on the instance.
        params.pop('transform_params', None)
        # Only values of existing keys are replaced, so iterating is safe
        for key, value in params.items():
            if not is_jsonable(value):
                # ``str(value)`` instead of ``value.__str__()``: the latter
                # raises TypeError if the attribute is itself a class, as the
                # lookup then yields an unbound method
                params[key] = str(value)
164
165
    @abstractmethod
166
    def apply_transform(self, subject: Subject):
167
        raise NotImplementedError
168
169
    @staticmethod
170
    def to_range(n, around):
171
        if around is None:
172
            return 0, n
173
        else:
174
            return around - n, around + n
175
176
    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Normalize a per-axis parameter into a flat tuple.

        Args:
            params: A single value or a sequence of values.
            around: Centre used to build each range from a single value; it
                is forwarded to :py:meth:`to_range`.
            name: Name of the parameter, used in error messages.
            make_ranges: If ``True``, the result is made of three
                ``(min, max)`` pairs and every pair is validated with
                :py:meth:`parse_range`.
            **kwargs: Constraints forwarded to :py:meth:`parse_range`.

        Returns:
            A tuple with the parsed values. It has six elements if
            ``make_ranges`` is ``True``.

        Raises:
            ValueError: If ``make_ranges`` is ``True`` and the values cannot
                be expanded into six numbers, or if any of the resulting
                ranges is rejected by :py:meth:`parse_range`.
        """
        params = to_tuple(params)
        # d or (a, b)
        if len(params) == 1 or (len(params) == 2 and make_ranges):
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3 and make_ranges:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        if make_ranges:
            if len(params) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                    f' not {len(params)}'
                )
                raise ValueError(message)
            # The length check used to guard this loop as well, which made
            # it unreachable, so the ranges were never validated
            for param_range in zip(params[::2], params[1::2]):
                self.parse_range(param_range, name, **kwargs)
        return tuple(params)
194
195
    @staticmethod
196
    def parse_range(
197
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
198
            name: str,
199
            min_constraint: TypeNumber = None,
200
            max_constraint: TypeNumber = None,
201
            type_constraint: type = None,
202
            ) -> Tuple[TypeNumber, TypeNumber]:
203
        r"""Adapted from ``torchvision.transforms.RandomRotation``.
204
205
        Args:
206
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
207
                where :math:`n_{min} \leq n_{max}`.
208
                If a single positive number :math:`n` is provided,
209
                :math:`n_{min} = -n` and :math:`n_{max} = n`.
210
            name: Name of the parameter, so that an informative error message
211
                can be printed.
212
            min_constraint: Minimal value that :math:`n_{min}` can take,
213
                default is None, i.e. there is no minimal value.
214
            max_constraint: Maximal value that :math:`n_{max}` can take,
215
                default is None, i.e. there is no maximal value.
216
            type_constraint: Precise type that :math:`n_{max}` and
217
                :math:`n_{min}` must take.
218
219
        Returns:
220
            A tuple of two numbers :math:`(n_{min}, n_{max})`.
221
222
        Raises:
223
            ValueError: if :attr:`nums_range` is negative
224
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
225
            ValueError: if :math:`n_{max} \lt n_{min}`
226
            ValueError: if :attr:`min_constraint` is not None and
227
                :math:`n_{min}` is smaller than :attr:`min_constraint`
228
            ValueError: if :attr:`max_constraint` is not None and
229
                :math:`n_{max}` is greater than :attr:`max_constraint`
230
            ValueError: if :attr:`type_constraint` is not None and
231
                :math:`n_{max}` and :math:`n_{max}` are not of type
232
                :attr:`type_constraint`.
233
        """
234
        if isinstance(nums_range, numbers.Number):  # single number given
235
            if nums_range < 0:
236
                raise ValueError(
237
                    f'If {name} is a single number,'
238
                    f' it must be positive, not {nums_range}')
239
            if min_constraint is not None and nums_range < min_constraint:
240
                raise ValueError(
241
                    f'If {name} is a single number, it must be greater'
242
                    f' than {min_constraint}, not {nums_range}'
243
                )
244
            if max_constraint is not None and nums_range > max_constraint:
245
                raise ValueError(
246
                    f'If {name} is a single number, it must be smaller'
247
                    f' than {max_constraint}, not {nums_range}'
248
                )
249
            if type_constraint is not None:
250
                if not isinstance(nums_range, type_constraint):
251
                    raise ValueError(
252
                        f'If {name} is a single number, it must be of'
253
                        f' type {type_constraint}, not {nums_range}'
254
                    )
255
            min_range = -nums_range if min_constraint is None else nums_range
256
            return (min_range, nums_range)
257
258
        try:
259
            min_value, max_value = nums_range
260
        except (TypeError, ValueError):
261
            raise ValueError(
262
                f'If {name} is not a single number, it must be'
263
                f' a sequence of len 2, not {nums_range}'
264
            )
265
266
        min_is_number = isinstance(min_value, numbers.Number)
267
        max_is_number = isinstance(max_value, numbers.Number)
268
        if not min_is_number or not max_is_number:
269
            message = (
270
                f'{name} values must be numbers, not {nums_range}')
271
            raise ValueError(message)
272
273
        if min_value > max_value:
274
            raise ValueError(
275
                f'If {name} is a sequence, the second value must be'
276
                f' equal or greater than the first, but it is {nums_range}')
277
278
        if min_constraint is not None and min_value < min_constraint:
279
            raise ValueError(
280
                f'If {name} is a sequence, the first value must be greater'
281
                f' than {min_constraint}, but it is {min_value}'
282
            )
283
284
        if max_constraint is not None and max_value > max_constraint:
285
            raise ValueError(
286
                f'If {name} is a sequence, the second value must be smaller'
287
                f' than {max_constraint}, but it is {max_value}'
288
            )
289
290
        if type_constraint is not None:
291
            min_type_ok = isinstance(min_value, type_constraint)
292
            max_type_ok = isinstance(max_value, type_constraint)
293
            if not min_type_ok or not max_type_ok:
294
                raise ValueError(
295
                    f'If "{name}" is a sequence, its values must be of'
296
                    f' type "{type_constraint}", not "{type(nums_range)}"'
297
                )
298
        return nums_range
299
300
    @staticmethod
301
    def parse_probability(probability: float) -> float:
302
        is_number = isinstance(probability, numbers.Number)
303
        if not (is_number and 0 <= probability <= 1):
304
            message = (
305
                'Probability must be a number in [0, 1],'
306
                f' not {probability}'
307
            )
308
            raise ValueError(message)
309
        return probability
310
311
    @staticmethod
    def parse_subject(subject: Subject) -> None:
        """Check that the object to transform is a subject.

        Raises:
            RuntimeError: If ``subject`` is not an instance of ``Subject``.
        """
        if isinstance(subject, Subject):
            return
        raise RuntimeError(
            'Input to a transform must be a tensor or an instance'
            f' of torchio.Subject, not "{type(subject)}"'
        )
319
320
    def parse_tensor(self, data: TypeData) -> Subject:
321
        if data.ndim != 4:
322
            message = (
323
                'The input must be a 4D tensor with dimensions'
324
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
325
            )
326
            raise ValueError(message)
327
        return self._get_subject_from_tensor(data)
328
329
    @staticmethod
    def parse_interpolation(interpolation: str) -> Interpolation:
        """Convert the name of an interpolation into an ``Interpolation``.

        Args:
            interpolation: Case-insensitive name of a member of
                ``Interpolation``. Passing a member directly is still
                accepted, but a ``FutureWarning`` is issued.

        Returns:
            The corresponding member of ``Interpolation``.

        Raises:
            AttributeError: If the string does not match any member name.
            TypeError: If the value is neither a string nor a member of
                ``Interpolation``.
        """
        if isinstance(interpolation, Interpolation):
            warnings.warn(
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead',
                FutureWarning,
            )
            return interpolation

        if not isinstance(interpolation, str):
            raise TypeError(
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )

        # The lowercase form is the one shown in the error message
        lowered = interpolation.lower()
        supported_values = [key.name.lower() for key in Interpolation]
        if lowered not in supported_values:
            raise AttributeError(
                f'Interpolation "{lowered}" is not among'
                f' the supported values: {supported_values}'
            )
        return getattr(Interpolation, lowered.upper())
355
356
    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        """Build a subject whose only image is a ``ScalarImage`` of ``tensor``."""
        return self._get_subject_from_image(ScalarImage(tensor=tensor))
359
360
    def _get_subject_from_image(self, image: Image) -> Subject:
        """Build a subject holding ``image`` under ``default_image_name``."""
        return Subject({self.default_image_name: image})
363
364
    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        """Build a subject from a plain dictionary.

        Args:
            data: Dictionary to convert.
            image_keys: Keys whose values are wrapped in a ``ScalarImage``.
                All other values are copied as they are.

        Returns:
            A ``Subject`` with the same keys as ``data``.
        """
        entries = {
            key: ScalarImage(tensor=value) if key in image_keys else value
            for key, value in data.items()
        }
        return Subject(entries)
375
376
    def _get_subject_from_sitk_image(self, image):
        """Build a single-image subject from a SimpleITK image.

        The image is converted with ``sitk_to_nib`` and the resulting tensor
        and affine are used to create a ``ScalarImage``.
        """
        tensor, affine = sitk_to_nib(image)
        scalar_image = ScalarImage(tensor=tensor, affine=affine)
        return self._get_subject_from_image(scalar_image)
380
381
    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Convert data and affine to SimpleITK via the module-level helper."""
        sitk_image = nib_to_sitk(data, affine)
        return sitk_image
384
385
    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Convert a SimpleITK image via the module-level helper."""
        converted = sitk_to_nib(image)
        return converted
388
389
    @property
390
    def name(self):
391
        return self.__class__.__name__
392
393
    def is_invertible(self):
394
        inverse_method = getattr(self, 'inverse', None)
395
        return inverse_method is not None and callable(inverse_method)
396