Passed
Pull Request — master (#403)
by Fernando
01:12
created

torchio.transforms.transform.Transform.name()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import copy
2
import numbers
3
import warnings
4
from typing import Union, Tuple
5
from abc import ABC, abstractmethod
6
from contextlib import contextmanager
7
8
import torch
9
import numpy as np
10
import SimpleITK as sitk
11
12
from ..utils import to_tuple
13
from ..data.subject import Subject
14
from ..data.io import nib_to_sitk, sitk_to_nib
15
from ..data.image import LabelMap
16
from ..typing import (
17
    TypeKeys,
18
    TypeData,
19
    TypeNumber,
20
    TypeCallable,
21
    TypeTripletInt,
22
)
23
from .interpolation import Interpolation, get_sitk_interpolator
24
from .data_parser import DataParser, TypeTransformInput
25
26
# Six integers: one pair of bounds per spatial axis.
TypeSixBounds = Tuple[int, int, int, int, int, int]
# A bounds parameter: a single int, one int per axis, or six explicit ints.
TypeBounds = Union[int, TypeTripletInt, TypeSixBounds]
# A masking method: a name, a callable, a bounds parameter or nothing.
TypeMaskingMethod = Union[str, TypeCallable, TypeBounds, None]
# Anatomical labels accepted as string masking methods.
anat_axes = (
    'Left',
    'Right',
    'Anterior',
    'Posterior',
    'Inferior',
    'Superior',
)
34
35
36
class Transform(ABC):
37
    """Abstract class for all TorchIO transforms.
38
39
    All subclasses must overwrite
40
    :meth:`Transform.apply_transform`,
41
    which takes data, applies some transformation and returns the result.
42
43
    The input can be an instance of
44
    :class:`torchio.Subject`,
45
    :class:`torchio.Image`,
46
    :class:`numpy.ndarray`,
47
    :class:`torch.Tensor`,
48
    :class:`SimpleITK.Image`,
49
    or :class:`dict`.
50
51
    Args:
52
        p: Probability that this transform will be applied.
53
        copy: Make a shallow copy of the input before applying the transform.
54
        include: Sequence of strings with the names of the only images to which
55
            the transform will be applied.
56
            Mandatory if the input is a :class:`dict`.
57
        exclude: Sequence of strings with the names of the images to which
58
            the transform will not be applied, apart from the ones that are
59
            excluded because of the transform type.
60
            For example, if a subject includes an MRI, a CT and a label map,
61
            and the CT is added to the list of exclusions of an intensity
62
            transform such as :class:`~torchio.transforms.RandomBlur`,
63
            the transform will be only applied to the MRI, as the label map is
64
            excluded by default by spatial transforms.
65
    """
66
    def __init__(
67
            self,
68
            p: float = 1,
69
            copy: bool = True,
70
            include: TypeKeys = None,
71
            exclude: TypeKeys = None,
72
            keys: TypeKeys = None,
73
            ):
74
        self.probability = self.parse_probability(p)
75
        self.copy = copy
76
        if keys is not None:
77
            message = (
78
                'The "keys" argument is deprecated and will be removed in the'
79
                ' future. Use "include" instead.'
80
            )
81
            warnings.warn(message, DeprecationWarning)
82
            include = keys
83
        self.include, self.exclude = self.parse_include_and_exclude(
84
            include, exclude)
85
86
    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        With probability ``1 - p`` the input is returned untouched.

        Args:
            data: Instance of :class:`~torchio.Subject`, 4D
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix will be set to identity. Other
                valid input types are a SimpleITK image, a
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
                :class:`dict`. The output type is the same as the input type.
        """
        skip = torch.rand(1).item() > self.probability
        if skip:
            return data
        parser = DataParser(data, keys=self.include)
        subject = parser.get_subject()
        if self.copy:
            subject = copy.copy(subject)
        # Turn NumPy floating-point warnings into errors while transforming
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)
        self.add_transform_to_subject_history(transformed)
        # Every image must still be 4D after the transform
        for image in transformed.get_images(intensity_only=False):
            ndim = image.data.ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'
        return parser.get_output(transformed)
117
118
    def __repr__(self):
119
        if hasattr(self, 'args_names'):
120
            names = self.args_names
121
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
122
            if hasattr(self, 'invert_transform') and self.invert_transform:
123
                args_strings.append('invert=True')
124
            args_string = ', '.join(args_strings)
125
            return f'{self.name}({args_string})'
126
        else:
127
            return super().__repr__()
128
129
    @property
130
    def name(self):
131
        return self.__class__.__name__
132
133
    @abstractmethod
134
    def apply_transform(self, subject: Subject):
135
        raise NotImplementedError
136
137
    def add_transform_to_subject_history(self, subject):
        """Register this transform and its arguments in the subject history.

        Nothing is recorded when this transform is an instance of one of the
        types listed below.
        """
        # Imported here rather than at module level
        # NOTE(review): presumably to avoid circular imports - confirm
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple
        from .preprocessing.label import SequentialLabels
        # NOTE(review): presumably these delegate to other transforms that
        # record themselves - confirm
        not_recorded = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
            EnsureShapeMultiple,
            SequentialLabels,
        )
        if isinstance(self, not_recorded):
            return
        subject.add_transform(self, self._get_reproducing_arguments())
151
152
    @staticmethod
153
    def to_range(n, around):
154
        if around is None:
155
            return 0, n
156
        else:
157
            return around - n, around + n
158
159
    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Expand a per-axis parameter and optionally validate it as ranges.

        Args:
            params: A single value or a sequence of values.
            around: Centre used to turn three single values into ranges;
                if ``None``, each value ``n`` becomes ``(0, n)``.
            name: Parameter name used in error messages.
            make_ranges: If ``True``, the result must hold six values that
                are validated as three ``(min, max)`` pairs.
            **kwargs: Forwarded to :meth:`_parse_range`.

        Returns:
            Tuple with the expanded values.

        Raises:
            ValueError: If ``make_ranges`` is ``True`` and the expanded
                parameter does not have six values.
        """
        values = to_tuple(params)
        # d -> (d, d, d); with ranges, (a, b) -> (a, b, a, b, a, b)
        if len(values) == 1 or (make_ranges and len(values) == 2):
            values *= 3
        if make_ranges:
            if len(values) == 3:
                # (a, b, c) -> one (low, high) pair per axis, flattened
                flattened = []
                for value in values:
                    flattened.extend(self.to_range(value, around))
                values = flattened
            if len(values) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or'
                    f' 6, not {len(values)}'
                )
                raise ValueError(message)
            for start in range(0, 6, 2):
                pair = values[start], values[start + 1]
                self._parse_range(pair, name, **kwargs)
        return tuple(values)
178
179
    @staticmethod
180
    def _parse_range(
181
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
182
            name: str,
183
            min_constraint: TypeNumber = None,
184
            max_constraint: TypeNumber = None,
185
            type_constraint: type = None,
186
            ) -> Tuple[TypeNumber, TypeNumber]:
187
        r"""Adapted from :class:`torchvision.transforms.RandomRotation`.
188
189
        Args:
190
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
191
                where :math:`n_{min} \leq n_{max}`.
192
                If a single positive number :math:`n` is provided,
193
                :math:`n_{min} = -n` and :math:`n_{max} = n`.
194
            name: Name of the parameter, so that an informative error message
195
                can be printed.
196
            min_constraint: Minimal value that :math:`n_{min}` can take,
197
                default is None, i.e. there is no minimal value.
198
            max_constraint: Maximal value that :math:`n_{max}` can take,
199
                default is None, i.e. there is no maximal value.
200
            type_constraint: Precise type that :math:`n_{max}` and
201
                :math:`n_{min}` must take.
202
203
        Returns:
204
            A tuple of two numbers :math:`(n_{min}, n_{max})`.
205
206
        Raises:
207
            ValueError: if :attr:`nums_range` is negative
208
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
209
            ValueError: if :math:`n_{max} \lt n_{min}`
210
            ValueError: if :attr:`min_constraint` is not None and
211
                :math:`n_{min}` is smaller than :attr:`min_constraint`
212
            ValueError: if :attr:`max_constraint` is not None and
213
                :math:`n_{max}` is greater than :attr:`max_constraint`
214
            ValueError: if :attr:`type_constraint` is not None and
215
                :math:`n_{max}` and :math:`n_{max}` are not of type
216
                :attr:`type_constraint`.
217
        """
218
        if isinstance(nums_range, numbers.Number):  # single number given
219
            if nums_range < 0:
220
                raise ValueError(
221
                    f'If {name} is a single number,'
222
                    f' it must be positive, not {nums_range}')
223
            if min_constraint is not None and nums_range < min_constraint:
224
                raise ValueError(
225
                    f'If {name} is a single number, it must be greater'
226
                    f' than {min_constraint}, not {nums_range}'
227
                )
228
            if max_constraint is not None and nums_range > max_constraint:
229
                raise ValueError(
230
                    f'If {name} is a single number, it must be smaller'
231
                    f' than {max_constraint}, not {nums_range}'
232
                )
233
            if type_constraint is not None:
234
                if not isinstance(nums_range, type_constraint):
235
                    raise ValueError(
236
                        f'If {name} is a single number, it must be of'
237
                        f' type {type_constraint}, not {nums_range}'
238
                    )
239
            min_range = -nums_range if min_constraint is None else nums_range
240
            return (min_range, nums_range)
241
242
        try:
243
            min_value, max_value = nums_range
244
        except (TypeError, ValueError):
245
            raise ValueError(
246
                f'If {name} is not a single number, it must be'
247
                f' a sequence of len 2, not {nums_range}'
248
            )
249
250
        min_is_number = isinstance(min_value, numbers.Number)
251
        max_is_number = isinstance(max_value, numbers.Number)
252
        if not min_is_number or not max_is_number:
253
            message = (
254
                f'{name} values must be numbers, not {nums_range}')
255
            raise ValueError(message)
256
257
        if min_value > max_value:
258
            raise ValueError(
259
                f'If {name} is a sequence, the second value must be'
260
                f' equal or greater than the first, but it is {nums_range}')
261
262
        if min_constraint is not None and min_value < min_constraint:
263
            raise ValueError(
264
                f'If {name} is a sequence, the first value must be greater'
265
                f' than {min_constraint}, but it is {min_value}'
266
            )
267
268
        if max_constraint is not None and max_value > max_constraint:
269
            raise ValueError(
270
                f'If {name} is a sequence, the second value must be smaller'
271
                f' than {max_constraint}, but it is {max_value}'
272
            )
273
274
        if type_constraint is not None:
275
            min_type_ok = isinstance(min_value, type_constraint)
276
            max_type_ok = isinstance(max_value, type_constraint)
277
            if not min_type_ok or not max_type_ok:
278
                raise ValueError(
279
                    f'If "{name}" is a sequence, its values must be of'
280
                    f' type "{type_constraint}", not "{type(nums_range)}"'
281
                )
282
        return nums_range
283
284
    @staticmethod
285
    def parse_interpolation(interpolation: str) -> str:
286
        if not isinstance(interpolation, str):
287
            itype = type(interpolation)
288
            raise TypeError(f'Interpolation must be a string, not {itype}')
289
        interpolation = interpolation.lower()
290
        is_string = isinstance(interpolation, str)
291
        supported_values = [key.name.lower() for key in Interpolation]
292
        is_supported = interpolation.lower() in supported_values
293
        if is_string and is_supported:
294
            return interpolation
295
        message = (
296
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
297
            f' must be a string among the supported values: {supported_values}'
298
        )
299
        raise ValueError(message)
300
301
    @staticmethod
302
    def parse_probability(probability: float) -> float:
303
        is_number = isinstance(probability, numbers.Number)
304
        if not (is_number and 0 <= probability <= 1):
305
            message = (
306
                'Probability must be a number in [0, 1],'
307
                f' not {probability}'
308
            )
309
            raise ValueError(message)
310
        return probability
311
312
    @staticmethod
313
    def parse_include_and_exclude(
314
            include: TypeKeys = None,
315
            exclude: TypeKeys = None,
316
            ) -> Tuple[TypeKeys, TypeKeys]:
317
        if include is not None and exclude is not None:
318
            raise ValueError('Include and exclude cannot both be specified')
319
        return include, exclude
320
321
    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Delegate to the module-level :func:`nib_to_sitk` helper."""
        image = nib_to_sitk(data, affine)
        return image
324
325
    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Delegate to the module-level :func:`sitk_to_nib` helper."""
        converted = sitk_to_nib(image)
        return converted
328
329
    def _get_reproducing_arguments(self):
330
        """
331
        Return a dictionary with the arguments that would be necessary to
332
        reproduce the transform exactly.
333
        """
334
        reproducing_arguments = {
335
            'include': self.include,
336
            'exclude': self.exclude,
337
            'copy': self.copy,
338
        }
339
        args_names = {name: getattr(self, name) for name in self.args_names}
340
        reproducing_arguments.update(args_names)
341
        return reproducing_arguments
342
343
    def is_invertible(self):
344
        return hasattr(self, 'invert_transform')
345
346
    def inverse(self):
347
        if not self.is_invertible():
348
            raise RuntimeError(f'{self.name} is not invertible')
349
        new = copy.deepcopy(self)
350
        new.invert_transform = not self.invert_transform
351
        return new
352
353
    @staticmethod
354
    @contextmanager
355
    def _use_seed(seed):
356
        """Perform an operation using a specific seed for the PyTorch RNG"""
357
        torch_rng_state = torch.random.get_rng_state()
358
        torch.manual_seed(seed)
359
        yield
360
        torch.random.set_rng_state(torch_rng_state)
361
362
    @staticmethod
    def get_sitk_interpolator(interpolation: str) -> int:
        """Delegate to the module-level :func:`get_sitk_interpolator`."""
        interpolator = get_sitk_interpolator(interpolation)
        return interpolator
365
366
    @staticmethod
367
    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
368
        try:
369
            bounds_parameters = tuple(bounds_parameters)
370
        except TypeError:
371
            bounds_parameters = (bounds_parameters,)
372
373
        # Check that numbers are integers
374
        for number in bounds_parameters:
375
            if not isinstance(number, (int, np.integer)) or number < 0:
376
                message = (
377
                    'Bounds values must be integers greater or equal to zero,'
378
                    f' not "{bounds_parameters}" of type {type(number)}'
379
                )
380
                raise ValueError(message)
381
        bounds_parameters = tuple(int(n) for n in bounds_parameters)
382
        bounds_parameters_length = len(bounds_parameters)
383
        if bounds_parameters_length == 6:
384
            return bounds_parameters
385
        if bounds_parameters_length == 1:
386
            return 6 * bounds_parameters
387
        if bounds_parameters_length == 3:
388
            return tuple(np.repeat(bounds_parameters, 2).tolist())
389
        message = (
390
            'Bounds parameter must be an integer or a tuple of'
391
            f' 3 or 6 integers, not {bounds_parameters}'
392
        )
393
        raise ValueError(message)
394
395
    @staticmethod
396
    def ones(tensor: torch.Tensor) -> torch.Tensor:
397
        return torch.ones_like(tensor, dtype=torch.bool)
398
399
    @staticmethod
400
    def mean(tensor: torch.Tensor) -> torch.Tensor:
401
        mask = tensor > tensor.mean()
402
        return mask
403
404
    @staticmethod
405
    def get_mask_from_masking_method(
406
            masking_method: TypeMaskingMethod,
407
            subject: Subject,
408
            tensor: torch.Tensor,
409
            ) -> torch.Tensor:
410
        if masking_method is None:
411
            return Transform.ones(tensor)
412
        elif callable(masking_method):
413
            return masking_method(tensor)
414
        elif type(masking_method) is str:
415
            in_subject = masking_method in subject
416
            if in_subject and isinstance(subject[masking_method], LabelMap):
417
                return subject[masking_method].data.bool()
418
            masking_method = masking_method.capitalize()
419
            if masking_method in anat_axes:
420
                return Transform.get_mask_from_anatomical_label(
421
                    masking_method, tensor)
422
        elif type(masking_method) in (tuple, list, int):
423
            return Transform.get_mask_from_bounds(masking_method, tensor)
424
        message = (
425
            'Masking method parameter must be a function, a label map name,'
426
            f' an anatomical label: {anat_axes}, or a bounds parameter'
427
            ' (an int, tuple of 3 ints, or tuple of 6 ints),'
428
            f' not {masking_method} of type {type(masking_method)}'
429
        )
430
        raise ValueError(message)
431
432
    @staticmethod
    def get_mask_from_anatomical_label(
            anatomical_label: str,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        """Return a boolean mask selecting one half of a 4D tensor.

        The tensor is split in the middle of one of its last three
        dimensions. ``'Left'``/``'Right'`` select the lower/upper half of
        dimension 1, ``'Posterior'``/``'Anterior'`` of dimension 2 and
        ``'Inferior'``/``'Superior'`` of dimension 3.

        Args:
            anatomical_label: One of the labels in ``anat_axes``
                (case-insensitive).
            tensor: 4D tensor for which the mask is computed.

        Raises:
            ValueError: If the label is not in ``anat_axes``.
        """
        label = anatomical_label.title()
        if label not in anat_axes:
            raise ValueError(
                f'Anatomical label must be one of {anat_axes}'
                f' not {label}'
            )
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        _, width, height, depth = tensor.shape
        # label -> (dimension to split, half of it that is selected)
        halves = {
            'Right': (1, slice(width // 2, None)),
            'Left': (1, slice(None, width // 2)),
            'Anterior': (2, slice(height // 2, None)),
            'Posterior': (2, slice(None, height // 2)),
            'Superior': (3, slice(depth // 2, None)),
            'Inferior': (3, slice(None, depth // 2)),
        }
        dimension, half = halves[label]
        index = [slice(None)] * 4
        index[dimension] = half
        mask[tuple(index)] = True
        return mask
459
460
    @staticmethod
461
    def get_mask_from_bounds(
462
            bounds_parameters: TypeBounds,
463
            tensor: torch.Tensor,
464
            ) -> torch.Tensor:
465
        bounds_parameters = Transform.parse_bounds(bounds_parameters)
466
        low = bounds_parameters[::2]
467
        high = bounds_parameters[1::2]
468
        i0, j0, k0 = low
469
        i1, j1, k1 = np.array(tensor.shape[1:]) - high
470
        mask = torch.zeros_like(tensor, dtype=torch.bool)
471
        mask[:, i0:i1, j0:j1, k0:k1] = True
472
        return mask
473