Passed
Pull Request — master (#625)
by Fernando
01:52
created

Transform.parse_axes()   A

Complexity

Conditions 4

Size

Total Lines 15
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 15
rs 9.8
c 0
b 0
f 0
cc 4
nop 1
1
import copy
2
import numbers
3
import warnings
4
from abc import ABC, abstractmethod
5
from contextlib import contextmanager
6
from typing import Union, Tuple, Optional, Dict, Sequence
7
8
import torch
9
import numpy as np
10
import SimpleITK as sitk
11
12
from ..utils import to_tuple
13
from ..data.subject import Subject
14
from ..data.io import nib_to_sitk, sitk_to_nib
15
from ..data.image import LabelMap
16
from ..typing import (
17
    TypeKeys,
18
    TypeData,
19
    TypeNumber,
20
    TypeCallable,
21
    TypeTripletInt,
22
)
23
from .interpolation import Interpolation, get_sitk_interpolator
24
from .data_parser import DataParser, TypeTransformInput
25
26
# Six non-negative margins, two per spatial axis, laid out as
# (i_low, i_high, j_low, j_high, k_low, k_high): the even positions are
# counted from the start of each axis and the odd ones from its end.
TypeSixBounds = Tuple[int, int, int, int, int, int]

# Bounds may be abbreviated: a single int is used for all six margins and a
# triplet is used once per axis for both of its margins.
TypeBounds = Union[int, TypeTripletInt, TypeSixBounds]

# Anything accepted by ``Transform.get_mask_from_masking_method``.
TypeMaskingMethod = Union[str, TypeCallable, TypeBounds, None]

# Names of the anatomical directions, listed as opposite pairs per axis.
ANATOMICAL_AXES = (
    'Left',
    'Right',
    'Posterior',
    'Anterior',
    'Inferior',
    'Superior',
)
38
39
40
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    When called, the input can be an instance of
    :class:`torchio.Subject`,
    :class:`torchio.Image`,
    :class:`numpy.ndarray`,
    :class:`torch.Tensor`,
    :class:`SimpleITK.Image`,
    or :class:`dict` containing 4D tensors as values.

    All subclasses must overwrite
    :meth:`~torchio.transforms.Transform.apply_transform`,
    which takes an instance of :class:`~torchio.Subject`,
    modifies it and returns the result.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        include: Sequence of strings with the names of the only images to which
            the transform will be applied.
            Mandatory if the input is a :class:`dict`.
        exclude: Sequence of strings with the names of the images to which
            the transform will not be applied, apart from the ones that are
            excluded because of the transform type.
            For example, if a subject includes an MRI, a CT and a label map,
            and the CT is added to the list of exclusions of an intensity
            transform such as :class:`~torchio.transforms.RandomBlur`,
            the transform will be only applied to the MRI, as the label map is
            excluded by default by intensity transforms.
        keys: Deprecated alias of ``include``. If it is not ``None``, a
            warning is issued and its value is used instead of ``include``.
        keep: Dictionary with the names of the images that will be kept in the
            subject and their new names.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            include: TypeKeys = None,
            exclude: TypeKeys = None,
            keys: TypeKeys = None,
            keep: Optional[Dict[str, str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        if keys is not None:
            message = (
                'The "keys" argument is deprecated and will be removed in the'
                ' future. Use "include" instead.'
            )
            warnings.warn(message)
            include = keys
        self.include, self.exclude = self.parse_include_and_exclude(
            include, exclude)
        self.keep = keep
        # args_names is the sequence of parameters from self that need to be
        # passed to a non-random version of a random transform. They are also
        # used to invert invertible transforms
        self.args_names = ()

    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :class:`torchio.Subject`, 4D
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix will be set to identity. Other
                valid input types are a SimpleITK image, a
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
                :class:`dict`. The output type is the same as the input type.

        Returns:
            The input itself, untouched, if the transform is randomly skipped
            according to its probability. Otherwise, the transformed data.
        """
        if torch.rand(1).item() > self.probability:
            return data
        data_parser = DataParser(data, keys=self.include)
        subject = data_parser.get_subject()
        # Copies of the images to keep are taken before the transform runs,
        # and stored under their new names
        images_to_keep = {}
        if self.keep is not None:
            images_to_keep = {
                new_name: copy.copy(subject[name])
                for name, new_name in self.keep.items()
            }
        if self.copy:
            subject = copy.copy(subject)
        # Turn NumPy floating-point problems into exceptions, except underflow
        with np.errstate(all='raise', under='ignore'):
            transformed = self.apply_transform(subject)
        if self.keep is not None:
            for name, image in images_to_keep.items():
                transformed.add_image(image, name)
        self.add_transform_to_subject_history(transformed)
        # Sanity check: every image must still be 4D after the transform
        for image in transformed.get_images(intensity_only=False):
            ndim = image.data.ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        output = data_parser.get_output(transformed)
        return output

    def __repr__(self):
        """Return ``Name(arg=value, ...)`` built from ``args_names``.

        ``invert=True`` is appended if the instance has a truthy
        ``invert_transform`` attribute. If ``args_names`` has not been set,
        the default representation is used.
        """
        if hasattr(self, 'args_names'):
            names = self.args_names
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
            if hasattr(self, 'invert_transform') and self.invert_transform:
                args_strings.append('invert=True')
            args_string = ', '.join(args_strings)
            return f'{self.name}({args_string})'
        else:
            return super().__repr__()

    @property
    def name(self):
        """Name of the class of this instance."""
        return self.__class__.__name__

    @abstractmethod
    def apply_transform(self, subject: Subject) -> Subject:
        """Apply the transform to a subject. Must be overwritten."""
        raise NotImplementedError

    def add_transform_to_subject_history(self, subject):
        """Record this transform and its arguments in the subject.

        Nothing is recorded for instances of the classes listed in
        ``call_others`` below.
        """
        # Imported here and not at module level, presumably to avoid circular
        # imports as these modules depend on this one
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple
        from .preprocessing import SequentialLabels
        call_others = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
            EnsureShapeMultiple,
            SequentialLabels,
        )
        if not isinstance(self, call_others):
            subject.add_transform(self, self._get_reproducing_arguments())

    @staticmethod
    def to_range(n, around):
        """Return ``(0, n)`` if ``around`` is ``None``, else a range of
        half-width ``n`` centered at ``around``."""
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Expand a compact parameter specification to one entry per axis.

        Args:
            params: Single value or sequence of values.
            around: Center of the ranges created from three values. If
                ``None``, each range starts at zero.
            name: Name of the parameter, used in error messages.
            make_ranges: If ``True``, the result contains a validated
                ``(min, max)`` pair per axis, i.e. six values.
            **kwargs: Constraints passed to :meth:`_parse_range`.

        Returns:
            Tuple with the expanded parameters.

        Raises:
            ValueError: if ``make_ranges`` is ``True`` and six values cannot
                be obtained from ``params``, or if a range is not valid.
        """
        params = to_tuple(params)
        # d or (a, b)
        if len(params) == 1 or (len(params) == 2 and make_ranges):
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3 and make_ranges:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # e.g. (0, a, 0, b, 0, c) if around is None,
            # (-a, a, -b, b, -c, c) if it is 0
            # or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c) if it is 1
            params = [n for prange in items for n in prange]
        if make_ranges:
            if len(params) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or'
                    f' 6, not {len(params)}'
                )
                raise ValueError(message)
            for param_range in zip(params[::2], params[1::2]):
                self._parse_range(param_range, name, **kwargs)
        return tuple(params)

    @staticmethod
    def _parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from :class:`torchvision.transforms.RandomRotation`.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`, unless
                :attr:`min_constraint` is given, in which case
                :math:`n_{min} = n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{max}` and :math:`n_{min}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            # The range is symmetric around zero only if there is no lower
            # constraint; otherwise it collapses to the number itself
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                # Report the types of the values, not of their container
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not "{type(min_value)}"'
                    f' and "{type(max_value)}"'
                )
        return nums_range

    @staticmethod
    def parse_interpolation(interpolation: str) -> str:
        """Return the lowercase name of a supported interpolation.

        Raises:
            TypeError: if ``interpolation`` is not a string.
            ValueError: if it is not the name, in any letter case, of a
                member of ``Interpolation``.
        """
        if not isinstance(interpolation, str):
            itype = type(interpolation)
            raise TypeError(f'Interpolation must be a string, not {itype}')
        interpolation = interpolation.lower()
        supported_values = [key.name.lower() for key in Interpolation]
        if interpolation in supported_values:
            return interpolation
        message = (
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
            f' must be a string among the supported values: {supported_values}'
        )
        raise ValueError(message)

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Return ``probability`` if it is a number in [0, 1].

        Raises:
            ValueError: if it is not a number or it is out of range.
        """
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_include_and_exclude(
            include: TypeKeys = None,
            exclude: TypeKeys = None,
            ) -> Tuple[TypeKeys, TypeKeys]:
        """Return both arguments unchanged if at most one is specified.

        Raises:
            ValueError: if neither ``include`` nor ``exclude`` is ``None``.
        """
        if include is not None and exclude is not None:
            raise ValueError('Include and exclude cannot both be specified')
        return include, exclude

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Convenience wrapper of the I/O function with the same name."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Convenience wrapper of the I/O function with the same name."""
        return sitk_to_nib(image)

    def _get_reproducing_arguments(self):
        """
        Return a dictionary with the arguments that would be necessary to
        reproduce the transform exactly.

        It contains ``include``, ``exclude`` and ``copy``, plus the value of
        each attribute named in ``args_names``.
        """
        reproducing_arguments = {
            'include': self.include,
            'exclude': self.exclude,
            'copy': self.copy,
        }
        args_names = {name: getattr(self, name) for name in self.args_names}
        reproducing_arguments.update(args_names)
        return reproducing_arguments

    def is_invertible(self):
        """Return ``True`` if the instance has an ``invert_transform``
        attribute."""
        return hasattr(self, 'invert_transform')

    def inverse(self):
        """Return a deep copy of the transform with ``invert_transform``
        negated.

        Raises:
            RuntimeError: if the transform is not invertible.
        """
        if not self.is_invertible():
            raise RuntimeError(f'{self.name} is not invertible')
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new

    @staticmethod
    @contextmanager
    def _use_seed(seed):
        """Perform an operation using a specific seed for the PyTorch RNG"""
        torch_rng_state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        try:
            yield
        finally:
            # Restore the previous state even if the managed block raises, so
            # that a failure does not leave the global RNG seeded
            torch.random.set_rng_state(torch_rng_state)

    @staticmethod
    def get_sitk_interpolator(interpolation: str) -> int:
        """Convenience wrapper of the function with the same name."""
        return get_sitk_interpolator(interpolation)

    @staticmethod
    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
        """Expand bounds given as one, three or six integers to six integers.

        A single value is repeated six times and each value of a triplet is
        repeated twice, e.g. ``(a, b, c)`` becomes ``(a, a, b, b, c, c)``.

        Raises:
            ValueError: if a value is not a non-negative integer or the
                number of values is not 1, 3 or 6.
        """
        try:
            bounds_parameters = tuple(bounds_parameters)
        except TypeError:
            # Not iterable, so it is treated as a single value
            bounds_parameters = (bounds_parameters,)

        # Check that numbers are integers
        for number in bounds_parameters:
            if not isinstance(number, (int, np.integer)) or number < 0:
                message = (
                    'Bounds values must be integers greater or equal to zero,'
                    f' not "{bounds_parameters}" of type {type(number)}'
                )
                raise ValueError(message)
        bounds_parameters = tuple(int(n) for n in bounds_parameters)
        bounds_parameters_length = len(bounds_parameters)
        if bounds_parameters_length == 6:
            return bounds_parameters
        if bounds_parameters_length == 1:
            return 6 * bounds_parameters
        if bounds_parameters_length == 3:
            return tuple(np.repeat(bounds_parameters, 2).tolist())
        message = (
            'Bounds parameter must be an integer or a tuple of'
            f' 3 or 6 integers, not {bounds_parameters}'
        )
        raise ValueError(message)

    @staticmethod
    def ones(tensor: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask of ones with the shape of the input."""
        return torch.ones_like(tensor, dtype=torch.bool)

    @staticmethod
    def mean(tensor: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask of the values greater than the mean."""
        mask = tensor > tensor.float().mean()
        return mask

    def get_mask_from_masking_method(
            self,
            masking_method: TypeMaskingMethod,
            subject: Subject,
            tensor: torch.Tensor,
            labels: list = None,
            ) -> torch.Tensor:
        """Compute a boolean mask for a tensor.

        Args:
            masking_method: One of:

                - ``None``: the mask is all ones.
                - A callable: it is called with ``tensor`` and its output is
                  returned.
                - The name of a label map in ``subject``.
                - An anatomical label in ``ANATOMICAL_AXES`` or its initial,
                  in any letter case.
                - Bounds (int, tuple or list), see
                  :meth:`get_mask_from_bounds`.
            subject: Subject in which label maps are looked up.
            tensor: Tensor for which the mask is computed.
            labels: If the masking method is the name of a label map, values
                of the label map that are part of the mask. If ``None``, all
                non-zero values are. If empty, the mask is all zeros.

        Raises:
            ValueError: if the masking method is not valid.
        """
        if masking_method is None:
            return self.ones(tensor)
        elif callable(masking_method):
            return masking_method(tensor)
        elif type(masking_method) is str:
            in_subject = masking_method in subject
            if in_subject and isinstance(subject[masking_method], LabelMap):
                mask_data = subject[masking_method].data
                if labels is None:
                    return mask_data.bool()
                # Union of the voxels that are equal to any of the labels
                mask = torch.zeros_like(mask_data, dtype=torch.bool)
                for label in labels:
                    mask |= mask_data == label
                return mask
            possible_axis = masking_method.capitalize()
            # The initials are listed as valid values in the error message
            # below, so they are expanded to the full names
            initials = {axis[0]: axis for axis in ANATOMICAL_AXES}
            possible_axis = initials.get(possible_axis, possible_axis)
            if possible_axis in ANATOMICAL_AXES:
                return self.get_mask_from_anatomical_label(
                    possible_axis, tensor)
        elif type(masking_method) in (tuple, list, int):
            return self.get_mask_from_bounds(masking_method, tensor)
        first_anat_axes = tuple(s[0] for s in ANATOMICAL_AXES)
        message = (
            'Masking method must be one of:\n'
            ' 1) A callable object, such as a function\n'
            ' 2) The name of a label map in the subject'
            f' ({subject.get_images_names()})\n'
            f' 3) An anatomical label {ANATOMICAL_AXES + first_anat_axes}\n'
            ' 4) A bounds parameter'
            ' (int, tuple of 3 ints, or tuple of 6 ints)\n'
            f' The passed value, "{masking_method}",'
            f' of type "{type(masking_method)}", is not valid'
        )
        raise ValueError(message)

    @staticmethod
    def get_mask_from_anatomical_label(
            anatomical_label: str,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        """Return a mask that is ``True`` in one half of a 4D tensor.

        Args:
            anatomical_label: One of ``ANATOMICAL_AXES``, in any letter case.
            tensor: 4D tensor whose first dimension is not spatial.

        Raises:
            ValueError: if the anatomical label is not valid.
        """
        # Assume the image is in RAS orientation
        anatomical_label = anatomical_label.capitalize()
        if anatomical_label not in ANATOMICAL_AXES:
            message = (
                f'Anatomical label must be one of {ANATOMICAL_AXES}'
                f' not {anatomical_label}'
            )
            raise ValueError(message)
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        _, width, height, depth = tensor.shape
        if anatomical_label == 'Right':
            mask[:, width // 2:] = True
        elif anatomical_label == 'Left':
            mask[:, :width // 2] = True
        elif anatomical_label == 'Anterior':
            mask[:, :, height // 2:] = True
        elif anatomical_label == 'Posterior':
            mask[:, :, :height // 2] = True
        elif anatomical_label == 'Superior':
            mask[:, :, :, depth // 2:] = True
        elif anatomical_label == 'Inferior':
            mask[:, :, :, :depth // 2] = True
        return mask

    def get_mask_from_bounds(
            self,
            bounds_parameters: TypeBounds,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        """Return a mask that is ``True`` inside a box of a 4D tensor.

        The box is obtained by leaving out, along each spatial axis, as many
        voxels at the start and at the end as specified by the bounds, see
        :meth:`parse_bounds`.
        """
        bounds_parameters = self.parse_bounds(bounds_parameters)
        low = bounds_parameters[::2]
        high = bounds_parameters[1::2]
        i0, j0, k0 = low
        # The upper margins are counted from the end of each spatial axis
        i1, j1, k1 = np.array(tensor.shape[1:]) - high
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        mask[:, i0:i1, j0:j1, k0:k1] = True
        return mask

    @staticmethod
    def parse_axes(axes: Union[int, Sequence[int]]) -> Tuple[int, ...]:
        """Ensure that all values in sequence are in [0, 1, 2] or strings.

        Returns:
            The axes as a tuple.

        Raises:
            ValueError: if an axis is neither a string nor 0, 1 or 2.
        """
        axes_tuple = to_tuple(axes)
        for axis in axes_tuple:
            is_int = isinstance(axis, int)
            is_string = isinstance(axis, str)
            valid_number = is_int and axis in (0, 1, 2)
            if not is_string and not valid_number:
                message = (
                    f'All axes must be 0, 1 or 2, but found "{axis}"'
                    f' with type {type(axis)}'
                )
                raise ValueError(message)
        return axes_tuple

    @staticmethod
    def ensure_axes_indices(
            subject: Subject,
            axes: Sequence[int],
            ) -> Sequence[int]:
        """Convert axes names to indices, if any of the axes is a string.

        If so, the orientation of the images in the subject is checked for
        consistency and the first image is used for the conversion. The
        result is a sorted list. Otherwise, ``axes`` is returned unchanged.
        """
        if any(isinstance(n, str) for n in axes):
            subject.check_consistent_orientation()
            image = subject.get_first_image()
            # NOTE(review): the offset suggests that axis_name_to_index
            # returns indices counted from the end (-3 to -1) - confirm
            axes = sorted(3 + image.axis_name_to_index(n) for n in axes)
        return axes
533