Passed
Push — master ( b9ac52...6aebda )
by Fernando
10:37 queued 20s
created

torchio.transforms.transform.Transform.inverse()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
import copy
2
import numbers
3
import warnings
4
from typing import Union, Tuple
5
from abc import ABC, abstractmethod
6
from contextlib import contextmanager
7
8
import torch
9
import numpy as np
10
import SimpleITK as sitk
11
12
from ..utils import to_tuple
13
from ..data.subject import Subject
14
from ..data.io import nib_to_sitk, sitk_to_nib
15
from ..data.image import LabelMap
16
from ..typing import TypeData, TypeNumber, TypeKeys, TypeCallable, TypeTripletInt
17
from .interpolation import Interpolation, get_sitk_interpolator
18
from .data_parser import DataParser, TypeTransformInput
19
20
# Six integers ordered as (low, high) for each of the three spatial axes;
# see Transform.get_mask_from_bounds for how they are consumed.
TypeSixBounds = Tuple[int, int, int, int, int, int]
# One int, three ints or six ints; normalised by Transform.parse_bounds.
TypeBounds = Union[int, TypeTripletInt, TypeSixBounds]
# Everything accepted by Transform.get_mask as "masking_method".
TypeMaskingMethod = Union[str, TypeCallable, TypeBounds, None]
27
28
29
class Transform(ABC):
30
    """Abstract class for all TorchIO transforms.
31
32
    All subclasses must override
33
    :meth:`Transform.apply_transform`,
34
    which takes data, applies some transformation and returns the result.
35
36
    The input can be an instance of
37
    :class:`torchio.Subject`,
38
    :class:`torchio.Image`,
39
    :class:`numpy.ndarray`,
40
    :class:`torch.Tensor`,
41
    :class:`SimpleITK.Image`,
42
    or :class:`dict`.
43
44
    Args:
45
        p: Probability that this transform will be applied.
46
        copy: Make a shallow copy of the input before applying the transform.
47
        include: Sequence of strings with the names of the only images to which
48
            the transform will be applied.
49
            Mandatory if the input is a :class:`dict`.
50
        exclude: Sequence of strings with the names of the images to which the
51
            transform will not be applied, apart from the ones that are
52
            excluded because of the transform type.
53
            For example, if a subject includes an MRI, a CT and a label map, and
54
            the CT is added to the list of exclusions of an intensity transform
55
            such as :class:`~torchio.transforms.RandomBlur`,
56
            the transform will be only applied to the MRI, as the label map is
57
            excluded by default by spatial transforms.
58
    """
59
    def __init__(
60
            self,
61
            p: float = 1,
62
            copy: bool = True,
63
            include: TypeKeys = None,
64
            exclude: TypeKeys = None,
65
            keys: TypeKeys = None,
66
            ):
67
        self.probability = self.parse_probability(p)
68
        self.copy = copy
69
        if keys is not None:
70
            message = (
71
                'The "keys" argument is deprecated and will be removed in the'
72
                ' future. Use "include" instead.'
73
            )
74
            warnings.warn(message, DeprecationWarning)
75
            include = keys
76
        self.include, self.exclude = self.parse_include_and_exclude(
77
            include, exclude)
78
79
    def __call__(
80
            self,
81
            data: TypeTransformInput,
82
            ) -> TypeTransformInput:
83
        """Transform data and return a result of the same type.
84
85
        Args:
86
            data: Instance of 1) :class:`~torchio.Subject`, 4D
87
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
88
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
89
                and :math:`W, H, D` are the spatial dimensions. If the input is
90
                a tensor, the affine matrix will be set to identity. Other
91
                valid input types are a SimpleITK image, a
92
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
93
                :class:`dict`. The output type is the same as the input type.
94
        """
95
        if torch.rand(1).item() > self.probability:
96
            return data
97
        data_parser = DataParser(data, keys=self.include)
98
        subject = data_parser.get_subject()
99
        if self.copy:
100
            subject = copy.copy(subject)
101
        with np.errstate(all='raise'):
102
            transformed = self.apply_transform(subject)
103
        self.add_transform_to_subject_history(transformed)
104
        for image in transformed.get_images(intensity_only=False):
105
            ndim = image.data.ndim
106
            assert ndim == 4, f'Output of {self.name} is {ndim}D'
107
108
        output = data_parser.get_output(transformed)
109
        return output
110
111
    def __repr__(self):
112
        if hasattr(self, 'args_names'):
113
            names = self.args_names
114
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
115
            if hasattr(self, 'invert_transform') and self.invert_transform:
116
                args_strings.append('invert=True')
117
            args_string = ', '.join(args_strings)
118
            return f'{self.name}({args_string})'
119
        else:
120
            return super().__repr__()
121
122
    @property
123
    def name(self):
124
        return self.__class__.__name__
125
126
    @abstractmethod
127
    def apply_transform(self, subject: Subject):
128
        raise NotImplementedError
129
130
    def add_transform_to_subject_history(self, subject):
        """Register this transform and its arguments in ``subject``.

        Nothing is recorded when this instance is a ``RandomTransform``,
        ``Compose``, ``OneOf``, ``CropOrPad`` or ``SequentialLabels``.
        """
        # NOTE(review): imported locally, presumably to avoid circular
        # imports between this module and its subpackages -- confirm.
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad
        from .preprocessing.label import SequentialLabels

        # NOTE(review): these presumably delegate to other transforms that
        # record themselves -- confirm.
        not_recorded = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
            SequentialLabels,
        )
        if isinstance(self, not_recorded):
            return
        arguments = self._get_reproducing_arguments()
        subject.add_transform(self, arguments)
143
144
    @staticmethod
145
    def to_range(n, around):
146
        if around is None:
147
            return 0, n
148
        else:
149
            return around - n, around + n
150
151
    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Expand ``params`` into per-axis values and return them as a tuple.

        Args:
            params: A single value or a sequence, converted with ``to_tuple``.
            around: Centre passed to :meth:`to_range` when three values are
                expanded into ranges.
            name: Parameter name used in error messages.
            make_ranges: If ``True``, the result is six values forming three
                ``(min, max)`` pairs, each validated by :meth:`_parse_range`.
                If ``False``, a single value is only repeated three times and
                nothing is validated.
            **kwargs: Forwarded to :meth:`_parse_range`.

        Raises:
            ValueError: If ``make_ranges`` is ``True`` and the values cannot
                be expanded to exactly six numbers.
        """
        values = to_tuple(params)
        count = len(values)
        if count == 1 or (make_ranges and count == 2):  # d or (a, b)
            values *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if not make_ranges:
            return tuple(values)

        if len(values) == 3:  # (a, b, c)
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            expanded = []
            for n in values:
                expanded.extend(self.to_range(n, around))
            values = expanded
        if len(values) != 6:
            message = (
                f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                f' not {len(values)}'
            )
            raise ValueError(message)
        for low, high in zip(values[0::2], values[1::2]):
            self._parse_range((low, high), name, **kwargs)
        return tuple(values)
169
170
    @staticmethod
171
    def _parse_range(
172
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
173
            name: str,
174
            min_constraint: TypeNumber = None,
175
            max_constraint: TypeNumber = None,
176
            type_constraint: type = None,
177
            ) -> Tuple[TypeNumber, TypeNumber]:
178
        r"""Adapted from :class:`torchvision.transforms.RandomRotation`.
179
180
        Args:
181
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
182
                where :math:`n_{min} \leq n_{max}`.
183
                If a single positive number :math:`n` is provided,
184
                :math:`n_{min} = -n` and :math:`n_{max} = n`.
185
            name: Name of the parameter, so that an informative error message
186
                can be printed.
187
            min_constraint: Minimal value that :math:`n_{min}` can take,
188
                default is None, i.e. there is no minimal value.
189
            max_constraint: Maximal value that :math:`n_{max}` can take,
190
                default is None, i.e. there is no maximal value.
191
            type_constraint: Precise type that :math:`n_{max}` and
192
                :math:`n_{min}` must take.
193
194
        Returns:
195
            A tuple of two numbers :math:`(n_{min}, n_{max})`.
196
197
        Raises:
198
            ValueError: if :attr:`nums_range` is negative
199
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
200
            ValueError: if :math:`n_{max} \lt n_{min}`
201
            ValueError: if :attr:`min_constraint` is not None and
202
                :math:`n_{min}` is smaller than :attr:`min_constraint`
203
            ValueError: if :attr:`max_constraint` is not None and
204
                :math:`n_{max}` is greater than :attr:`max_constraint`
205
            ValueError: if :attr:`type_constraint` is not None and
206
                :math:`n_{max}` and :math:`n_{max}` are not of type
207
                :attr:`type_constraint`.
208
        """
209
        if isinstance(nums_range, numbers.Number):  # single number given
210
            if nums_range < 0:
211
                raise ValueError(
212
                    f'If {name} is a single number,'
213
                    f' it must be positive, not {nums_range}')
214
            if min_constraint is not None and nums_range < min_constraint:
215
                raise ValueError(
216
                    f'If {name} is a single number, it must be greater'
217
                    f' than {min_constraint}, not {nums_range}'
218
                )
219
            if max_constraint is not None and nums_range > max_constraint:
220
                raise ValueError(
221
                    f'If {name} is a single number, it must be smaller'
222
                    f' than {max_constraint}, not {nums_range}'
223
                )
224
            if type_constraint is not None:
225
                if not isinstance(nums_range, type_constraint):
226
                    raise ValueError(
227
                        f'If {name} is a single number, it must be of'
228
                        f' type {type_constraint}, not {nums_range}'
229
                    )
230
            min_range = -nums_range if min_constraint is None else nums_range
231
            return (min_range, nums_range)
232
233
        try:
234
            min_value, max_value = nums_range
235
        except (TypeError, ValueError):
236
            raise ValueError(
237
                f'If {name} is not a single number, it must be'
238
                f' a sequence of len 2, not {nums_range}'
239
            )
240
241
        min_is_number = isinstance(min_value, numbers.Number)
242
        max_is_number = isinstance(max_value, numbers.Number)
243
        if not min_is_number or not max_is_number:
244
            message = (
245
                f'{name} values must be numbers, not {nums_range}')
246
            raise ValueError(message)
247
248
        if min_value > max_value:
249
            raise ValueError(
250
                f'If {name} is a sequence, the second value must be'
251
                f' equal or greater than the first, but it is {nums_range}')
252
253
        if min_constraint is not None and min_value < min_constraint:
254
            raise ValueError(
255
                f'If {name} is a sequence, the first value must be greater'
256
                f' than {min_constraint}, but it is {min_value}'
257
            )
258
259
        if max_constraint is not None and max_value > max_constraint:
260
            raise ValueError(
261
                f'If {name} is a sequence, the second value must be smaller'
262
                f' than {max_constraint}, but it is {max_value}'
263
            )
264
265
        if type_constraint is not None:
266
            min_type_ok = isinstance(min_value, type_constraint)
267
            max_type_ok = isinstance(max_value, type_constraint)
268
            if not min_type_ok or not max_type_ok:
269
                raise ValueError(
270
                    f'If "{name}" is a sequence, its values must be of'
271
                    f' type "{type_constraint}", not "{type(nums_range)}"'
272
                )
273
        return nums_range
274
275
    @staticmethod
276
    def parse_interpolation(interpolation: str) -> str:
277
        if not isinstance(interpolation, str):
278
            itype = type(interpolation)
279
            raise TypeError(f'Interpolation must be a string, not {itype}')
280
        interpolation = interpolation.lower()
281
        is_string = isinstance(interpolation, str)
282
        supported_values = [key.name.lower() for key in Interpolation]
283
        is_supported = interpolation.lower() in supported_values
284
        if is_string and is_supported:
285
            return interpolation
286
        message = (
287
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
288
            f' must be a string among the supported values: {supported_values}'
289
        )
290
        raise ValueError(message)
291
292
    @staticmethod
293
    def parse_probability(probability: float) -> float:
294
        is_number = isinstance(probability, numbers.Number)
295
        if not (is_number and 0 <= probability <= 1):
296
            message = (
297
                'Probability must be a number in [0, 1],'
298
                f' not {probability}'
299
            )
300
            raise ValueError(message)
301
        return probability
302
303
    @staticmethod
304
    def parse_include_and_exclude(
305
            include: TypeKeys = None,
306
            exclude: TypeKeys = None,
307
            ) -> Tuple[TypeKeys, TypeKeys]:
308
        if include is not None and exclude is not None:
309
            raise ValueError('Include and exclude cannot both be specified')
310
        return include, exclude
311
312
    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Forward to the module-level ``nib_to_sitk`` from ``..data.io``."""
        image = nib_to_sitk(data, affine)
        return image
315
316
    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Forward to the module-level ``sitk_to_nib`` from ``..data.io``."""
        converted = sitk_to_nib(image)
        return converted
319
320
    def _get_reproducing_arguments(self):
321
        """
322
        Return a dictionary with the arguments that would be necessary to
323
        reproduce the transform exactly.
324
        """
325
        reproducing_arguments = {'include': self.include, 'exclude': self.exclude, 'copy': self.copy}
326
        reproducing_arguments.update({name: getattr(self, name) for name in self.args_names})
327
        return reproducing_arguments
328
329
    def is_invertible(self):
330
        return hasattr(self, 'invert_transform')
331
332
    def inverse(self):
333
        if not self.is_invertible():
334
            raise RuntimeError(f'{self.name} is not invertible')
335
        new = copy.deepcopy(self)
336
        new.invert_transform = not self.invert_transform
337
        return new
338
339
    @staticmethod
340
    @contextmanager
341
    def _use_seed(seed):
342
        """Perform an operation using a specific seed for the PyTorch RNG"""
343
        torch_rng_state = torch.random.get_rng_state()
344
        torch.manual_seed(seed)
345
        yield
346
        torch.random.set_rng_state(torch_rng_state)
347
348
    @staticmethod
    def get_sitk_interpolator(interpolation: str) -> int:
        """Forward to ``get_sitk_interpolator`` from ``.interpolation``."""
        interpolator = get_sitk_interpolator(interpolation)
        return interpolator
351
352
    @staticmethod
353
    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
354
        try:
355
            bounds_parameters = tuple(bounds_parameters)
356
        except TypeError:
357
            bounds_parameters = (bounds_parameters,)
358
359
        # Check that numbers are integers
360
        for number in bounds_parameters:
361
            if not isinstance(number, (int, np.integer)) or number < 0:
362
                message = (
363
                    'Bounds values must be integers greater or equal to zero,'
364
                    f' not "{bounds_parameters}" of type {type(number)}'
365
                )
366
                raise ValueError(message)
367
        bounds_parameters = tuple(int(n) for n in bounds_parameters)
368
        bounds_parameters_length = len(bounds_parameters)
369
        if bounds_parameters_length == 6:
370
            return bounds_parameters
371
        if bounds_parameters_length == 1:
372
            return 6 * bounds_parameters
373
        if bounds_parameters_length == 3:
374
            return tuple(np.repeat(bounds_parameters, 2).tolist())
375
        message = (
376
            'Bounds parameter must be an integer or a tuple of'
377
            f' 3 or 6 integers, not {bounds_parameters}'
378
        )
379
        raise ValueError(message)
380
381
    @staticmethod
382
    def ones(tensor: torch.Tensor) -> torch.Tensor:
383
        return torch.ones_like(tensor, dtype=torch.bool)
384
385
    @staticmethod
386
    def mean(tensor: torch.Tensor) -> torch.Tensor:
387
        mask = tensor > tensor.mean()
388
        return mask
389
390
    @staticmethod
391
    def get_mask(masking_method: TypeMaskingMethod, subject: Subject, tensor: torch.Tensor) -> torch.Tensor:
392
        if masking_method is None:
393
            return Transform.ones(tensor)
394
        elif callable(masking_method):
395
            return masking_method(tensor)
396
        elif type(masking_method) is str:
397
            if masking_method in subject and isinstance(subject[masking_method], LabelMap):
398
                return subject[masking_method].data.bool()
399
            if masking_method.title() in ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior'):
400
                return Transform.get_mask_from_anatomical_label(masking_method.title(), tensor)
401
        elif type(masking_method) in (tuple, list, int):
402
            return Transform.get_mask_from_bounds(masking_method, tensor)
403
        message = (
404
            'Masking method parameter must be a function, a label map name,'
405
            " an anatomical label: ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior'),"
406
            ' or a bounds parameter: (an int, tuple of 3 ints, or tuple of 6 ints)'
407
            f' not {masking_method} of type {type(masking_method)}'
408
        )
409
        raise ValueError(message)
410
411
    @staticmethod
412
    def get_mask_from_anatomical_label(anatomical_label: str, tensor: torch.Tensor) -> torch.Tensor:
413
        anatomical_label = anatomical_label.title()
414
        if anatomical_label.title() not in ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior'):
415
            message = (
416
                "Anatomical label must be one of ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior')"
417
                f' not {anatomical_label}'
418
            )
419
            raise ValueError(message)
420
        mask = torch.zeros_like(tensor, dtype=torch.bool)
421
        _, width, height, depth = tensor.shape
422
        if anatomical_label == 'Right':
423
            mask[:, width // 2:] = True
424
        elif anatomical_label == 'Left':
425
            mask[:, :width // 2] = True
426
        elif anatomical_label == 'Anterior':
427
            mask[:, :, height // 2:] = True
428
        elif anatomical_label == 'Posterior':
429
            mask[:, :, :height // 2] = True
430
        elif anatomical_label == 'Superior':
431
            mask[:, :, :, depth // 2:] = True
432
        elif anatomical_label == 'Inferior':
433
            mask[:, :, :, :depth // 2] = True
434
        return mask
435
436
    @staticmethod
437
    def get_mask_from_bounds(bounds_parameters: TypeBounds, tensor: torch.Tensor) -> torch.Tensor:
438
        bounds_parameters = Transform.parse_bounds(bounds_parameters)
439
        low = bounds_parameters[::2]
440
        high = bounds_parameters[1::2]
441
        i0, j0, k0 = low
442
        i1, j1, k1 = np.array(tensor.shape[1:]) - high
443
        mask = torch.zeros_like(tensor, dtype=torch.bool)
444
        mask[:, i0:i1, j0:j1, k0:k1] = True
445
        return mask
446