torchio.transforms.transform.Transform.inverse()   A

Complexity
    Conditions: 2

Size
    Total Lines: 6
    Code Lines: 6

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
cc        2
eloc      6
nop       1
dl        0
loc       6
rs        10
c         0
b         0
f         0
import copy
import numbers
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Union, Tuple, Optional, Dict

import torch
import numpy as np
import SimpleITK as sitk

from ..utils import to_tuple
from ..data.subject import Subject
from ..data.io import nib_to_sitk, sitk_to_nib
from ..data.image import LabelMap
from ..typing import (
    TypeKeys,
    TypeData,
    TypeNumber,
    TypeCallable,
    TypeTripletInt,
)
from .interpolation import Interpolation, get_sitk_interpolator
from .data_parser import DataParser, TypeTransformInput

TypeSixBounds = Tuple[int, int, int, int, int, int]
TypeBounds = Union[
    int,
    TypeTripletInt,
    TypeSixBounds,
    None,
]
TypeMaskingMethod = Union[str, TypeCallable, TypeBounds, None]
ANATOMICAL_AXES = (
    'Left', 'Right',
    'Posterior', 'Anterior',
    'Inferior', 'Superior',
)


class Transform(ABC):
    """Abstract class for all TorchIO transforms.
43
44
    When called, the input can be an instance of
45
    :class:`torchio.Subject`,
46
    :class:`torchio.Image`,
47
    :class:`numpy.ndarray`,
48
    :class:`torch.Tensor`,
49
    :class:`SimpleITK.Image`,
50
    or :class:`dict` containing 4D tensors as values.
51
52
    All subclasses must overwrite
53
    :meth:`~torchio.transforms.Transform.apply_transform`,
54
    which takes an instance of :class:`~torchio.Subject`,
55
    modifies it and returns the result.
56
57
    Args:
58
        p: Probability that this transform will be applied.
59
        copy: Make a shallow copy of the input before applying the transform.
60
        include: Sequence of strings with the names of the only images to which
61
            the transform will be applied.
62
            Mandatory if the input is a :class:`dict`.
63
        exclude: Sequence of strings with the names of the images to which the
64
            the transform will not be applied, apart from the ones that are
65
            excluded because of the transform type.
66
            For example, if a subject includes an MRI, a CT and a label map,
67
            and the CT is added to the list of exclusions of an intensity
68
            transform such as :class:`~torchio.transforms.RandomBlur`,
69
            the transform will be only applied to the MRI, as the label map is
70
            excluded by default by spatial transforms.
71
        keep: Dictionary with the names of the images that will be kept in the
72
            subject and their new names.
73
        parse_input: If ``True``, the input will be converted to an instance of
74
            :class:`~torchio.Subject`. This is used internally by some special
75
            transforms like
76
            :class:`~torchio.transforms.augmentation.composition.Compose`.
77
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            include: TypeKeys = None,
            exclude: TypeKeys = None,
            keys: TypeKeys = None,
            keep: Optional[Dict[str, str]] = None,
            parse_input: bool = True,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        if keys is not None:
            message = (
                'The "keys" argument is deprecated and will be removed in the'
                ' future. Use "include" instead.'
            )
            warnings.warn(message)
            include = keys
        self.include, self.exclude = self.parse_include_and_exclude(
            include, exclude)
        self.keep = keep
        self.parse_input = parse_input
        # args_names is the sequence of parameters from self that need to be
        # passed to a non-random version of a random transform. They are also
        # used to invert invertible transforms
        self.args_names = ()

    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :class:`torchio.Subject`, 4D
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix will be set to identity. Other
                valid input types are a SimpleITK image, a
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
                :class:`dict`. The output type is the same as the input type.
        """
        if torch.rand(1).item() > self.probability:
            return data

        # Some transforms such as Compose should not modify the input data
        if self.parse_input:
            data_parser = DataParser(data, keys=self.include)
            subject = data_parser.get_subject()
        else:
            subject = data

        if self.keep is not None:
            images_to_keep = {}
            for name, new_name in self.keep.items():
                images_to_keep[new_name] = copy.copy(subject[name])
        if self.copy:
            subject = copy.copy(subject)
        with np.errstate(all='raise', under='ignore'):
            transformed = self.apply_transform(subject)
        if self.keep is not None:
            for name, image in images_to_keep.items():

[Issue] The variable images_to_keep does not seem to be defined in case the
check `if self.keep is not None` above is False. Are you sure this can never
be the case?

                transformed.add_image(image, name)

        if self.parse_input:
            self.add_transform_to_subject_history(transformed)
            for image in transformed.get_images(intensity_only=False):
                ndim = image.data.ndim
                assert ndim == 4, f'Output of {self.name} is {ndim}D'
            output = data_parser.get_output(transformed)

[Issue] The variable data_parser does not seem to be defined for all execution
paths.

        else:
            output = transformed

        return output
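
    # Call sketch (illustrative; assumes the RescaleIntensity transform from
    # this package): __call__ preserves the input type, and a bare 4D tensor
    # is treated as a (C, W, H, D) volume with an identity affine, as the
    # docstring above describes.
    #
    #     >>> import torch
    #     >>> import torchio as tio
    #     >>> transform = tio.RescaleIntensity(out_min_max=(0, 1))
    #     >>> tensor = torch.rand(1, 64, 64, 32)  # (C, W, H, D)
    #     >>> output = transform(tensor)
    #     >>> type(output) is torch.Tensor
    #     True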

    def __repr__(self):
        if hasattr(self, 'args_names'):
            names = self.args_names
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
            if hasattr(self, 'invert_transform') and self.invert_transform:
                args_strings.append('invert=True')
            args_string = ', '.join(args_strings)
            return f'{self.name}({args_string})'
        else:
            return super().__repr__()

    @property
    def name(self):
        return self.__class__.__name__

    @abstractmethod
    def apply_transform(self, subject: Subject) -> Subject:
        raise NotImplementedError

    def add_transform_to_subject_history(self, subject):
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple
        from .preprocessing import SequentialLabels, Resize
        call_others = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
            EnsureShapeMultiple,
            SequentialLabels,
            Resize,
        )
        if not isinstance(self, call_others):
            subject.add_transform(self, self._get_reproducing_arguments())

    @staticmethod
    def to_range(n, around):
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        params = to_tuple(params)
        # d or (a, b)
        if len(params) == 1 or (len(params) == 2 and make_ranges):
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3 and make_ranges:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        if make_ranges:
            if len(params) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or'
                    f' 6, not {len(params)}'
                )
                raise ValueError(message)
            for param_range in zip(params[::2], params[1::2]):
                self._parse_range(param_range, name, **kwargs)
        return tuple(params)
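
    # Expansion sketch for parse_params (illustrative; `t` stands for any
    # concrete transform instance): a single value becomes three symmetric
    # ranges around `around`, and a pair is repeated once per axis.
    #
    #     >>> t.parse_params(10, 0, 'degrees')
    #     (-10, 10, -10, 10, -10, 10)
    #     >>> t.parse_params(0.1, 1, 'scales')
    #     (0.9, 1.1, 0.9, 1.1, 0.9, 1.1)
    #     >>> t.parse_params((0.9, 1.1), 1, 'scales')
    #     (0.9, 1.1, 0.9, 1.1, 0.9, 1.1)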

    @staticmethod
    def _parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from :class:`torchvision.transforms.RandomRotation`.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not "{type(nums_range)}"'
                )
        return nums_range
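
    # Behaviour sketch for _parse_range (illustrative):
    #
    #     >>> Transform._parse_range(3, 'std')
    #     (-3, 3)
    #     >>> Transform._parse_range((0, 2), 'std', min_constraint=0)
    #     (0, 2)
    #     >>> Transform._parse_range((2, 1), 'std')  # second value < first
    #     Traceback (most recent call last):
    #         ...
    #     ValueError: If std is a sequence, the second value must be equal or greater than the first, but it is (2, 1)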

    @staticmethod
    def parse_interpolation(interpolation: str) -> str:
        if not isinstance(interpolation, str):
            itype = type(interpolation)
            raise TypeError(f'Interpolation must be a string, not {itype}')
        interpolation = interpolation.lower()
        is_string = isinstance(interpolation, str)
        supported_values = [key.name.lower() for key in Interpolation]
        is_supported = interpolation.lower() in supported_values
        if is_string and is_supported:
            return interpolation
        message = (
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
            f' must be a string among the supported values: {supported_values}'
        )
        raise ValueError(message)

    @staticmethod
    def parse_probability(probability: float) -> float:
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_include_and_exclude(
            include: TypeKeys = None,
            exclude: TypeKeys = None,
            ) -> Tuple[TypeKeys, TypeKeys]:
        if include is not None and exclude is not None:
            raise ValueError('Include and exclude cannot both be specified')
        return include, exclude

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        return sitk_to_nib(image)

    def _get_reproducing_arguments(self):
        """
        Return a dictionary with the arguments that would be necessary to
        reproduce the transform exactly.
        """
        reproducing_arguments = {
            'include': self.include,
            'exclude': self.exclude,
            'copy': self.copy,
        }
        args_names = {name: getattr(self, name) for name in self.args_names}
        reproducing_arguments.update(args_names)
        return reproducing_arguments

    def is_invertible(self):
        return hasattr(self, 'invert_transform')

    def inverse(self):
        if not self.is_invertible():
            raise RuntimeError(f'{self.name} is not invertible')
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new
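
    # Inversion sketch (illustrative; `SomeTransform` is hypothetical): only
    # transforms that define an `invert_transform` attribute are invertible,
    # and `inverse` returns a deep copy with that flag toggled.
    #
    #     >>> t = SomeTransform()
    #     >>> if t.is_invertible():
    #     ...     t_inv = t.inverse()
    #     ...     assert t_inv.invert_transform != t.invert_transform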

    @staticmethod
    @contextmanager
    def _use_seed(seed):
        """Perform an operation using a specific seed for the PyTorch RNG"""
        torch_rng_state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        yield
        torch.random.set_rng_state(torch_rng_state)
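
    # Seeding sketch (illustrative): inside the context the global PyTorch RNG
    # is seeded, and the previous RNG state is restored on exit, so repeated
    # blocks with the same seed draw the same numbers.
    #
    #     >>> with Transform._use_seed(42):
    #     ...     a = torch.rand(1)
    #     >>> with Transform._use_seed(42):
    #     ...     b = torch.rand(1)
    #     >>> bool(torch.equal(a, b))
    #     True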

    @staticmethod
    def get_sitk_interpolator(interpolation: str) -> int:
        return get_sitk_interpolator(interpolation)

    @staticmethod
    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
        if bounds_parameters is None:
            return None
        try:
            bounds_parameters = tuple(bounds_parameters)
        except TypeError:
            bounds_parameters = (bounds_parameters,)

        # Check that numbers are integers
        for number in bounds_parameters:
            if not isinstance(number, (int, np.integer)) or number < 0:
                message = (
                    'Bounds values must be integers greater or equal to zero,'
                    f' not "{bounds_parameters}" of type {type(number)}'
                )
                raise ValueError(message)
        bounds_parameters = tuple(int(n) for n in bounds_parameters)
        bounds_parameters_length = len(bounds_parameters)
        if bounds_parameters_length == 6:
            return bounds_parameters
        if bounds_parameters_length == 1:
            return 6 * bounds_parameters
        if bounds_parameters_length == 3:
            return tuple(np.repeat(bounds_parameters, 2).tolist())
        message = (
            'Bounds parameter must be an integer or a tuple of'
            f' 3 or 6 integers, not {bounds_parameters}'
        )
        raise ValueError(message)
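
    # Expansion sketch for parse_bounds (illustrative): a single integer or a
    # triplet of integers is broadcast to six values, one (low, high) pair per
    # spatial axis.
    #
    #     >>> Transform.parse_bounds(1)
    #     (1, 1, 1, 1, 1, 1)
    #     >>> Transform.parse_bounds((1, 2, 3))
    #     (1, 1, 2, 2, 3, 3)
    #     >>> Transform.parse_bounds((0, 1, 2, 3, 4, 5))
    #     (0, 1, 2, 3, 4, 5)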

    @staticmethod
    def ones(tensor: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(tensor, dtype=torch.bool)

    @staticmethod
    def mean(tensor: torch.Tensor) -> torch.Tensor:
        mask = tensor > tensor.float().mean()
        return mask

    def get_mask_from_masking_method(
            self,
            masking_method: TypeMaskingMethod,
            subject: Subject,
            tensor: torch.Tensor,
            labels: list = None,
            ) -> torch.Tensor:
        if masking_method is None:
            return self.ones(tensor)
        elif callable(masking_method):
            return masking_method(tensor)
        elif type(masking_method) is str:
            in_subject = masking_method in subject
            if in_subject and isinstance(subject[masking_method], LabelMap):
                if labels is None:
                    return subject[masking_method].data.bool()
                else:
                    mask_data = subject[masking_method].data
                    volumes = [mask_data == label for label in labels]
                    return torch.stack(volumes).sum(0).bool()
            possible_axis = masking_method.capitalize()
            if possible_axis in ANATOMICAL_AXES:
                return self.get_mask_from_anatomical_label(
                    possible_axis, tensor)
        elif type(masking_method) in (tuple, list, int):
            return self.get_mask_from_bounds(masking_method, tensor)
        first_anat_axes = tuple(s[0] for s in ANATOMICAL_AXES)
        message = (
            'Masking method must be one of:\n'
            ' 1) A callable object, such as a function\n'
            ' 2) The name of a label map in the subject'
            f' ({subject.get_images_names()})\n'
            f' 3) An anatomical label {ANATOMICAL_AXES + first_anat_axes}\n'
            ' 4) A bounds parameter'
            ' (int, tuple of 3 ints, or tuple of 6 ints)\n'
            f' The passed value, "{masking_method}",'
            f' of type "{type(masking_method)}", is not valid'
        )
        raise ValueError(message)
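
    # Dispatch sketch for get_mask_from_masking_method (illustrative; `t` is
    # any concrete transform instance and `subject` contains a LabelMap named
    # 'brain'):
    #
    #     >>> tensor = torch.rand(1, 8, 8, 8)
    #     >>> t.get_mask_from_masking_method(None, subject, tensor)       # all True
    #     >>> t.get_mask_from_masking_method('brain', subject, tensor)    # from the label map
    #     >>> t.get_mask_from_masking_method('Left', subject, tensor)     # anatomical half-space
    #     >>> t.get_mask_from_masking_method((2, 2, 2), subject, tensor)  # from bounds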

    @staticmethod
    def get_mask_from_anatomical_label(
            anatomical_label: str,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        # Assume the image is in RAS orientation
        anatomical_label = anatomical_label.capitalize()
        if anatomical_label not in ANATOMICAL_AXES:
            message = (
                f'Anatomical label must be one of {ANATOMICAL_AXES}'
                f' not {anatomical_label}'
            )
            raise ValueError(message)
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        _, width, height, depth = tensor.shape
        if anatomical_label == 'Right':
            mask[:, width // 2:] = True
        elif anatomical_label == 'Left':
            mask[:, :width // 2] = True
        elif anatomical_label == 'Anterior':
            mask[:, :, height // 2:] = True
        elif anatomical_label == 'Posterior':
            mask[:, :, :height // 2] = True
        elif anatomical_label == 'Superior':
            mask[:, :, :, depth // 2:] = True
        elif anatomical_label == 'Inferior':
            mask[:, :, :, :depth // 2] = True
        return mask
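
    # Half-space sketch (illustrative): assuming RAS orientation, 'Left' marks
    # the lower half of the first spatial axis and 'Right' the upper half.
    #
    #     >>> tensor = torch.zeros(1, 4, 4, 4)
    #     >>> mask = Transform.get_mask_from_anatomical_label('Left', tensor)
    #     >>> bool(mask[:, :2].all()), bool(mask[:, 2:].any())
    #     (True, False)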

    def get_mask_from_bounds(
            self,
            bounds_parameters: TypeBounds,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        bounds_parameters = self.parse_bounds(bounds_parameters)
        low = bounds_parameters[::2]
        high = bounds_parameters[1::2]
        i0, j0, k0 = low
        i1, j1, k1 = np.array(tensor.shape[1:]) - high
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        mask[:, i0:i1, j0:j1, k0:k1] = True
        return mask
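
    # Bounds-mask sketch (illustrative; `t` is any concrete transform
    # instance): the parsed bounds remove voxels from the low and high ends of
    # each spatial axis, and everything in between is True.
    #
    #     >>> tensor = torch.zeros(1, 10, 10, 10)
    #     >>> mask = t.get_mask_from_bounds((2, 2, 2), tensor)
    #     >>> bool(mask[:, 2:8, 2:8, 2:8].all())
    #     True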