Passed
Pull Request — master (#353)
by Fernando
01:16
created

torchio.transforms.transform.Transform.name()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import copy
2
import numbers
3
from abc import ABC, abstractmethod
4
from typing import Optional, Union, Tuple, List, Sequence
5
6
import torch
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from .. import TypeData, DATA, AFFINE, TypeNumber
12
from ..data.subject import Subject
13
from ..data.image import Image, ScalarImage
14
from ..utils import nib_to_sitk, sitk_to_nib, to_tuple
15
from .interpolation import Interpolation
16
17
18
# Union of every input type accepted by ``Transform.__call__`` and
# ``DataToSubject``; the transform output has the same type as its input.
TypeTransformInput = Union[
    Subject, Image,
    torch.Tensor, np.ndarray,
    sitk.Image,
    dict,
    nib.Nifti1Image,
]
27
28
29
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All subclasses should overwrite
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes data, applies some transformation and returns the result.

    The input can be an instance of
    :py:class:`torchio.Subject`,
    :py:class:`torchio.Image`,
    :py:class:`numpy.ndarray`,
    :py:class:`torch.Tensor`,
    :py:class:`SimpleITK.Image`,
    :py:class:`nibabel.Nifti1Image`,
    or a Python dictionary.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        keys: Mandatory if the input is a Python dictionary. The transform will
            be applied only to the data in each key.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[Sequence[str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys

    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :py:class:`~torchio.Subject`, 4D
                :py:class:`torch.Tensor` or NumPy array with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix will be set to identity. Other
                valid input types are a SimpleITK image, a
                :py:class:`torchio.Image`, a NiBabel Nifti1 Image or a Python
                dictionary. The output type is the same as the input type.

        Returns:
            The input itself, untouched, when the transform is randomly
            skipped according to its probability; otherwise the transformed
            data converted back to the type of the input.
        """
        # Skip the transform with probability 1 - p
        if torch.rand(1).item() > self.probability:
            return data
        data_parser = DataToSubject(data, keys=self.keys)
        subject = data_parser.get_subject()
        if self.copy:
            # Here "copy" is the standard library module, not the flag
            subject = copy.copy(subject)
        # Make NumPy floating-point errors raise instead of emitting warnings
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)
        self.add_transform_to_subject_history(transformed)
        # Sanity check: every image must still be 4D after the transform
        for image in transformed.get_images(intensity_only=False):
            ndim = image[DATA].ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'
        output = data_parser.get_output(transformed)
        return output

    def __repr__(self):
        """Show the class name and its arguments if ``args_names`` exists."""
        if hasattr(self, 'args_names'):
            names = self.args_names
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
            if hasattr(self, 'invert_transform') and self.invert_transform:
                args_strings.append('invert=True')
            args_string = ', '.join(args_strings)
            return f'{self.name}({args_string})'
        else:
            return super().__repr__()

    @property
    def name(self):
        """Name of the class of this transform."""
        return self.__class__.__name__

    @abstractmethod
    def apply_transform(self, subject: Subject):
        """Apply the transform to a subject. Must be overwritten."""
        raise NotImplementedError

    def add_transform_to_subject_history(self, subject):
        """Record this transform and its arguments in the subject.

        Nothing is recorded for instances of ``RandomTransform``, ``Compose``,
        ``OneOf`` or ``CropOrPad``.
        """
        # Imported here rather than at module level, presumably to avoid
        # circular imports as those modules import this one
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad
        call_others = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
        )
        if not isinstance(self, call_others):
            subject.add_transform(self, self.get_arguments())

    @staticmethod
    def to_range(n, around):
        """Return ``(0, n)`` if ``around`` is ``None``, else ``around -+ n``.

        Args:
            n: Half-width of the range (or upper bound if ``around`` is
                ``None``).
            around: Center of the range, or ``None``.
        """
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Expand a parameter into one value or one range per dimension.

        Args:
            params: Single value or sequence of values.
            around: Center passed to :py:meth:`to_range` when three values
                are expanded into three ranges.
            name: Name of the parameter, used in error messages.
            make_ranges: If ``True``, the result is six values (three
                ``(min, max)`` pairs) and each pair is validated with
                :py:meth:`parse_range`.
            **kwargs: Additional arguments for :py:meth:`parse_range`.

        Returns:
            Tuple with the expanded parameters.

        Raises:
            ValueError: if ``make_ranges`` is ``True`` and the parameters
                cannot be expanded into six values.
        """
        params = to_tuple(params)
        if len(params) == 1 or (len(params) == 2 and make_ranges):  # d or (a, b)
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3 and make_ranges:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (around-a, around+a, around-b, around+b, around-c, around+c)
            # or (0, a, 0, b, 0, c) if around is None
            params = [n for prange in items for n in prange]
        if make_ranges:
            if len(params) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                    f' not {len(params)}'
                )
                raise ValueError(message)
            # Validate each consecutive (min, max) pair
            for param_range in zip(params[::2], params[1::2]):
                self.parse_range(param_range, name, **kwargs)
        return tuple(params)

    @staticmethod
    def parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`, unless
                :attr:`min_constraint` is not None, in which case
                :math:`n_{min} = n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is a single negative number
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            # The range is symmetric only if there is no minimal value
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError) as error:
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            ) from error

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                # Report the types of the values, not of the container
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not "{type(min_value)}"'
                    f' and "{type(max_value)}"'
                )
        return nums_range

    @staticmethod
    def parse_interpolation(interpolation: str) -> str:
        """Return the lowercase interpolation name if it is supported.

        Args:
            interpolation: Name of a member of ``Interpolation``. The
                comparison is case-insensitive.

        Raises:
            TypeError: if ``interpolation`` is not a string.
            ValueError: if ``interpolation`` is not a supported value.
        """
        if not isinstance(interpolation, str):
            itype = type(interpolation)
            raise TypeError(f'Interpolation must be a string, not {itype}')
        interpolation = interpolation.lower()
        supported_values = [key.name.lower() for key in Interpolation]
        if interpolation in supported_values:
            return interpolation
        message = (
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
            f' must be a string among the supported values: {supported_values}'
        )
        raise ValueError(message)

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Return the probability if it is a number in :math:`[0, 1]`.

        Raises:
            ValueError: if ``probability`` is not a number in
                :math:`[0, 1]`.
        """
        try:
            is_number = isinstance(probability, numbers.Number)
            is_valid = is_number and 0 <= probability <= 1
        except TypeError:
            # Numbers that cannot be ordered, e.g. complex numbers
            is_valid = False
        if not is_valid:
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Call the ``nib_to_sitk`` utility with the same arguments."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Call the ``sitk_to_nib`` utility with the same argument."""
        return sitk_to_nib(image)

    @staticmethod
    def fourier_transform(array: np.ndarray):
        """Compute the N-dimensional FFT, shifted with ``np.fft.fftshift``."""
        transformed = np.fft.fftn(array)
        fshift = np.fft.fftshift(transformed)
        return fshift

    @staticmethod
    def inv_fourier_transform(fshift: np.ndarray):
        """Undo the shift and compute the N-dimensional inverse FFT.

        The result of ``np.fft.ifftn`` is returned as is, i.e. complex.
        """
        f_ishift = np.fft.ifftshift(fshift)
        img_back = np.fft.ifftn(f_ishift)
        return img_back

    def get_arguments(self):
        """
        Return a dictionary with the arguments that would be necessary to
        reproduce the transform exactly.

        The names of the arguments are read from the attribute ``args_names``,
        which this base class does not define.
        """
        return {name: getattr(self, name) for name in self.args_names}

    def is_invertible(self):
        """Return ``True`` if the attribute ``invert_transform`` exists."""
        return hasattr(self, 'invert_transform')

    def inverse(self):
        """Return a deep copy of the transform with the inversion toggled.

        Raises:
            RuntimeError: if the transform is not invertible.
        """
        if not self.is_invertible():
            raise RuntimeError(f'{self.name} is not invertible')
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new
317
318
319
class DataToSubject:
    """Convert the input of a transform into a subject, and back.

    The type of the input is detected in :py:meth:`get_subject` and stored in
    boolean flags, which :py:meth:`get_output` reads to convert the
    transformed subject back to the type of the input.

    Args:
        data: Input passed to the transform.
        keys: Keys of the images, needed if ``data`` is a dictionary.
    """
    def __init__(
            self,
            data: TypeTransformInput,
            keys: Optional[Sequence[str]] = None,
            ):
        self.data = data
        self.keys = keys
        # Key of the single image in subjects created from non-subject inputs
        self.default_image_name = 'default_image_name'
        # No input type has been detected yet
        self.is_tensor = self.is_array = self.is_dict = False
        self.is_image = self.is_sitk = self.is_nib = False

    def get_subject(self):
        """Create a subject from the input and remember the input type.

        Raises:
            RuntimeError: if the input is a dictionary and no keys were given.
            ValueError: if the type of the input is not supported.
        """
        source = self.data
        # Subject is checked before dict so that subjects, which may be
        # dictionaries themselves (TODO confirm), are passed through as is
        if isinstance(source, nib.Nifti1Image):
            voxels = source.get_fdata(dtype=np.float32)
            scalar_image = ScalarImage(tensor=voxels, affine=source.affine)
            subject = self._get_subject_from_image(scalar_image)
            self.is_nib = True
        elif isinstance(source, (np.ndarray, torch.Tensor)):
            subject = self._parse_tensor(source)
            # Arrays raise both flags; tensors only is_tensor
            self.is_array = isinstance(source, np.ndarray)
            self.is_tensor = True
        elif isinstance(source, Image):
            subject = self._get_subject_from_image(source)
            self.is_image = True
        elif isinstance(source, Subject):
            subject = source
        elif isinstance(source, sitk.Image):
            subject = self._get_subject_from_sitk_image(source)
            self.is_sitk = True
        elif isinstance(source, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                raise RuntimeError(
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
            subject = self._get_subject_from_dict(source, self.keys)
            self.is_dict = True
        else:
            raise ValueError(f'Input type not recognized: {type(source)}')
        self._parse_subject(subject)
        return subject

    def get_output(self, transformed):
        """Convert the transformed subject back to the type of the input.

        Raises:
            RuntimeError: if the input was a NiBabel image and the result has
                more than one channel.
        """
        image_name = self.default_image_name
        if self.is_tensor or self.is_sitk:
            image = transformed[image_name]
            tensor = image[DATA]
            if self.is_array:
                return tensor.numpy()
            if self.is_sitk:
                return nib_to_sitk(tensor, image[AFFINE])
            return tensor
        if self.is_image:
            return transformed[image_name]
        if self.is_dict:
            # Images are replaced by their data; other values are kept
            return {
                key: value.data if isinstance(value, Image) else value
                for key, value in transformed.items()
            }
        if self.is_nib:
            image = transformed[image_name]
            channels = image[DATA]
            if len(channels) > 1:
                raise RuntimeError(
                    'Multichannel images not supported for input of type'
                    ' nibabel.nifti.Nifti1Image'
                )
            return nib.Nifti1Image(channels[0].numpy(), image[AFFINE])
        # The input was a subject, so the result is returned as is
        return transformed

    @staticmethod
    def _parse_subject(subject: Subject) -> None:
        """Raise ``RuntimeError`` if the argument is not a subject."""
        if isinstance(subject, Subject):
            return
        raise RuntimeError(
            'Input to a transform must be a tensor or an instance'
            f' of torchio.Subject, not "{type(subject)}"'
        )

    def _parse_tensor(self, data: TypeData) -> Subject:
        """Create a subject from a 4D tensor or array.

        Raises:
            ValueError: if the data is not 4D.
        """
        if data.ndim == 4:
            return self._get_subject_from_tensor(data)
        raise ValueError(
            'The input must be a 4D tensor with dimensions'
            f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
        )

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        """Wrap the tensor in a scalar image and create a subject from it."""
        return self._get_subject_from_image(ScalarImage(tensor=tensor))

    def _get_subject_from_image(self, image: Image) -> Subject:
        """Create a subject whose only entry is the image."""
        return Subject({self.default_image_name: image})

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        """Create a subject, making scalar images from the values in the keys.

        The values whose keys are not in ``image_keys`` are kept as they are.
        """
        entries = {
            key: ScalarImage(tensor=value) if key in image_keys else value
            for key, value in data.items()
        }
        return Subject(entries)

    def _get_subject_from_sitk_image(self, image):
        """Create a subject from the data and affine of a SimpleITK image."""
        data, matrix = sitk_to_nib(image)
        scalar_image = ScalarImage(tensor=data, affine=matrix)
        return self._get_subject_from_image(scalar_image)
436