Passed
Push — master ( db45b7...6deb01 )
by Fernando
01:30
created

Transform.parse_bounds()   C

Complexity

Conditions 9

Size

Total Lines 30
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 30
rs 6.6666
c 0
b 0
f 0
cc 9
nop 1
1
import copy
2
import numbers
3
import warnings
4
from abc import ABC, abstractmethod
5
from contextlib import contextmanager
6
from typing import Union, Tuple, Optional, Dict
7
8
import torch
9
import numpy as np
10
import SimpleITK as sitk
11
12
from ..utils import to_tuple
13
from ..data.subject import Subject
14
from ..data.io import nib_to_sitk, sitk_to_nib
15
from ..data.image import LabelMap
16
from ..typing import (
17
    TypeKeys,
18
    TypeData,
19
    TypeNumber,
20
    TypeCallable,
21
    TypeTripletInt,
22
)
23
from .interpolation import Interpolation, get_sitk_interpolator
24
from .data_parser import DataParser, TypeTransformInput
25
26
# Six bounds, read as a (low, high) pair for each of the three spatial axes.
TypeSixBounds = Tuple[int, int, int, int, int, int]
# Accepted spellings of a bounds parameter: a single integer, a triplet of
# integers, the six values themselves, or None.
TypeBounds = Optional[Union[int, TypeTripletInt, TypeSixBounds]]
# A masking method can be a string, a callable, a bounds parameter or None.
TypeMaskingMethod = Optional[Union[str, TypeCallable, TypeBounds]]
# Anatomical labels, listed as pairs of opposite directions.
ANATOMICAL_AXES = (
    'Left',
    'Right',
    'Posterior',
    'Anterior',
    'Inferior',
    'Superior',
)
39
40
41
class Transform(ABC):
42
    """Abstract class for all TorchIO transforms.
43
44
    When called, the input can be an instance of
45
    :class:`torchio.Subject`,
46
    :class:`torchio.Image`,
47
    :class:`numpy.ndarray`,
48
    :class:`torch.Tensor`,
49
    :class:`SimpleITK.Image`,
50
    or :class:`dict` containing 4D tensors as values.
51
52
    All subclasses must overwrite
53
    :meth:`~torchio.transforms.Transform.apply_transform`,
54
    which takes an instance of :class:`~torchio.Subject`,
55
    modifies it and returns the result.
56
57
    Args:
58
        p: Probability that this transform will be applied.
59
        copy: Make a shallow copy of the input before applying the transform.
60
        include: Sequence of strings with the names of the only images to which
61
            the transform will be applied.
62
            Mandatory if the input is a :class:`dict`.
63
        exclude: Sequence of strings with the names of the images to which
            the transform will not be applied, apart from the ones that are
65
            excluded because of the transform type.
66
            For example, if a subject includes an MRI, a CT and a label map,
67
            and the CT is added to the list of exclusions of an intensity
68
            transform such as :class:`~torchio.transforms.RandomBlur`,
69
            the transform will be only applied to the MRI, as the label map is
70
            excluded by default by spatial transforms.
71
        keep: Dictionary with the names of the images that will be kept in the
72
            subject and their new names.
73
    """
74
    def __init__(
75
            self,
76
            p: float = 1,
77
            copy: bool = True,
78
            include: TypeKeys = None,
79
            exclude: TypeKeys = None,
80
            keys: TypeKeys = None,
81
            keep: Optional[Dict[str, str]] = None,
82
            ):
83
        self.probability = self.parse_probability(p)
84
        self.copy = copy
85
        if keys is not None:
86
            message = (
87
                'The "keys" argument is deprecated and will be removed in the'
88
                ' future. Use "include" instead.'
89
            )
90
            warnings.warn(message)
91
            include = keys
92
        self.include, self.exclude = self.parse_include_and_exclude(
93
            include, exclude)
94
        self.keep = keep
95
        # args_names is the sequence of parameters from self that need to be
96
        # passed to a non-random version of a random transform. They are also
97
        # used to invert invertible transforms
98
        self.args_names = ()
99
100
    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        The input is returned untouched when a random draw from
        :func:`torch.rand` exceeds the probability of the transform.

        Args:
            data: Instance of :class:`torchio.Subject`, 4D
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix will be set to identity. Other
                valid input types are a SimpleITK image, a
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
                :class:`dict`. The output type is the same as the input type.
        """
        if torch.rand(1).item() > self.probability:
            return data
        parser = DataParser(data, keys=self.include)
        subject = parser.get_subject()
        # Always defined, so the loop after the transform needs no guard
        kept_images = {}
        if self.keep is not None:
            kept_images = {
                new_name: copy.copy(subject[name])
                for name, new_name in self.keep.items()
            }
        if self.copy:
            subject = copy.copy(subject)
        # Floating point errors other than underflow are raised, not warned
        with np.errstate(all='raise', under='ignore'):
            transformed = self.apply_transform(subject)
        for new_name, kept_image in kept_images.items():
            transformed.add_image(kept_image, new_name)
        self.add_transform_to_subject_history(transformed)
        for image in transformed.get_images(intensity_only=False):
            ndim = image.data.ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        return parser.get_output(transformed)
138
139
    def __repr__(self):
140
        if hasattr(self, 'args_names'):
141
            names = self.args_names
142
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
143
            if hasattr(self, 'invert_transform') and self.invert_transform:
144
                args_strings.append('invert=True')
145
            args_string = ', '.join(args_strings)
146
            return f'{self.name}({args_string})'
147
        else:
148
            return super().__repr__()
149
150
    @property
151
    def name(self):
152
        return self.__class__.__name__
153
154
    @abstractmethod
    def apply_transform(self, subject: Subject) -> Subject:
        """Apply the transform to a subject. Subclasses must override this.

        Args:
            subject: Subject handed over by :meth:`__call__`.

        Returns:
            The transformed subject.
        """
        raise NotImplementedError
157
158
    def add_transform_to_subject_history(self, subject):
        """Record this transform and its arguments in the subject.

        Nothing is recorded when the transform is an instance of one of the
        classes listed below.
        """
        # Imported here rather than at module level
        # NOTE(review): presumably to avoid circular imports, as these modules
        # are likely to import this one - confirm
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple
        from .preprocessing import SequentialLabels
        # NOTE(review): these transforms presumably delegate to other
        # transforms that add themselves to the history - confirm
        not_recorded = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
            EnsureShapeMultiple,
            SequentialLabels,
        )
        if isinstance(self, not_recorded):
            return
        subject.add_transform(self, self._get_reproducing_arguments())
172
173
    @staticmethod
174
    def to_range(n, around):
175
        if around is None:
176
            return 0, n
177
        else:
178
            return around - n, around + n
179
180
    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Expand a parameter into a tuple of values, validating ranges.

        Args:
            params: Single value or sequence of values.
            around: Center used to build a range from each of three values.
                If ``None``, each value ``n`` becomes the range ``(0, n)``.
            name: Name of the parameter, used in error messages.
            make_ranges: If ``True``, the result must contain six values, read
                as three consecutive ``(min, max)`` pairs that are validated
                with :meth:`_parse_range`.
            **kwargs: Forwarded to :meth:`_parse_range`.

        Returns:
            Tuple with the expanded values.

        Raises:
            ValueError: if ``make_ranges`` is ``True`` and six values cannot
                be obtained from ``params``.
        """
        values = to_tuple(params)
        num_values = len(values)
        # d or (a, b)
        if num_values == 1 or (make_ranges and num_values == 2):
            values *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if make_ranges and len(values) == 3:  # (a, b, c)
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            values = [
                limit
                for value in values
                for limit in self.to_range(value, around)
            ]
        if make_ranges:
            if len(values) != 6:
                raise ValueError(
                    f'If "{name}" is a sequence, it must have length 2, 3 or'
                    f' 6, not {len(values)}'
                )
            minima, maxima = values[::2], values[1::2]
            for min_and_max in zip(minima, maxima):
                self._parse_range(min_and_max, name, **kwargs)
        return tuple(values)
199
200
    @staticmethod
201
    def _parse_range(
202
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
203
            name: str,
204
            min_constraint: TypeNumber = None,
205
            max_constraint: TypeNumber = None,
206
            type_constraint: type = None,
207
            ) -> Tuple[TypeNumber, TypeNumber]:
208
        r"""Adapted from :class:`torchvision.transforms.RandomRotation`.
209
210
        Args:
211
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
212
                where :math:`n_{min} \leq n_{max}`.
213
                If a single positive number :math:`n` is provided,
214
                :math:`n_{min} = -n` and :math:`n_{max} = n`.
215
            name: Name of the parameter, so that an informative error message
216
                can be printed.
217
            min_constraint: Minimal value that :math:`n_{min}` can take,
218
                default is None, i.e. there is no minimal value.
219
            max_constraint: Maximal value that :math:`n_{max}` can take,
220
                default is None, i.e. there is no maximal value.
221
            type_constraint: Precise type that :math:`n_{max}` and
222
                :math:`n_{min}` must take.
223
224
        Returns:
225
            A tuple of two numbers :math:`(n_{min}, n_{max})`.
226
227
        Raises:
228
            ValueError: if :attr:`nums_range` is negative
229
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
230
            ValueError: if :math:`n_{max} \lt n_{min}`
231
            ValueError: if :attr:`min_constraint` is not None and
232
                :math:`n_{min}` is smaller than :attr:`min_constraint`
233
            ValueError: if :attr:`max_constraint` is not None and
234
                :math:`n_{max}` is greater than :attr:`max_constraint`
235
            ValueError: if :attr:`type_constraint` is not None and
236
                :math:`n_{max}` and :math:`n_{max}` are not of type
237
                :attr:`type_constraint`.
238
        """
239
        if isinstance(nums_range, numbers.Number):  # single number given
240
            if nums_range < 0:
241
                raise ValueError(
242
                    f'If {name} is a single number,'
243
                    f' it must be positive, not {nums_range}')
244
            if min_constraint is not None and nums_range < min_constraint:
245
                raise ValueError(
246
                    f'If {name} is a single number, it must be greater'
247
                    f' than {min_constraint}, not {nums_range}'
248
                )
249
            if max_constraint is not None and nums_range > max_constraint:
250
                raise ValueError(
251
                    f'If {name} is a single number, it must be smaller'
252
                    f' than {max_constraint}, not {nums_range}'
253
                )
254
            if type_constraint is not None:
255
                if not isinstance(nums_range, type_constraint):
256
                    raise ValueError(
257
                        f'If {name} is a single number, it must be of'
258
                        f' type {type_constraint}, not {nums_range}'
259
                    )
260
            min_range = -nums_range if min_constraint is None else nums_range
261
            return (min_range, nums_range)
262
263
        try:
264
            min_value, max_value = nums_range
265
        except (TypeError, ValueError):
266
            raise ValueError(
267
                f'If {name} is not a single number, it must be'
268
                f' a sequence of len 2, not {nums_range}'
269
            )
270
271
        min_is_number = isinstance(min_value, numbers.Number)
272
        max_is_number = isinstance(max_value, numbers.Number)
273
        if not min_is_number or not max_is_number:
274
            message = (
275
                f'{name} values must be numbers, not {nums_range}')
276
            raise ValueError(message)
277
278
        if min_value > max_value:
279
            raise ValueError(
280
                f'If {name} is a sequence, the second value must be'
281
                f' equal or greater than the first, but it is {nums_range}')
282
283
        if min_constraint is not None and min_value < min_constraint:
284
            raise ValueError(
285
                f'If {name} is a sequence, the first value must be greater'
286
                f' than {min_constraint}, but it is {min_value}'
287
            )
288
289
        if max_constraint is not None and max_value > max_constraint:
290
            raise ValueError(
291
                f'If {name} is a sequence, the second value must be smaller'
292
                f' than {max_constraint}, but it is {max_value}'
293
            )
294
295
        if type_constraint is not None:
296
            min_type_ok = isinstance(min_value, type_constraint)
297
            max_type_ok = isinstance(max_value, type_constraint)
298
            if not min_type_ok or not max_type_ok:
299
                raise ValueError(
300
                    f'If "{name}" is a sequence, its values must be of'
301
                    f' type "{type_constraint}", not "{type(nums_range)}"'
302
                )
303
        return nums_range
304
305
    @staticmethod
    def parse_interpolation(interpolation: str) -> str:
        """Check an interpolation name and return it in lowercase.

        Args:
            interpolation: Name of the interpolation, compared
                case-insensitively with the member names of
                :class:`Interpolation`.

        Raises:
            TypeError: if ``interpolation`` is not a string.
            ValueError: if ``interpolation`` is not a supported name.
        """
        if not isinstance(interpolation, str):
            itype = type(interpolation)
            raise TypeError(f'Interpolation must be a string, not {itype}')
        interpolation = interpolation.lower()
        supported_values = [member.name.lower() for member in Interpolation]
        if interpolation in supported_values:
            return interpolation
        raise ValueError(
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
            f' must be a string among the supported values: {supported_values}'
        )
321
322
    @staticmethod
323
    def parse_probability(probability: float) -> float:
324
        is_number = isinstance(probability, numbers.Number)
325
        if not (is_number and 0 <= probability <= 1):
326
            message = (
327
                'Probability must be a number in [0, 1],'
328
                f' not {probability}'
329
            )
330
            raise ValueError(message)
331
        return probability
332
333
    @staticmethod
334
    def parse_include_and_exclude(
335
            include: TypeKeys = None,
336
            exclude: TypeKeys = None,
337
            ) -> Tuple[TypeKeys, TypeKeys]:
338
        if include is not None and exclude is not None:
339
            raise ValueError('Include and exclude cannot both be specified')
340
        return include, exclude
341
342
    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Thin wrapper around ``nib_to_sitk`` imported from ``..data.io``.

        The arguments are forwarded as they are and the result is returned.
        """
        # The name resolves to the module-level import, not to this method
        return nib_to_sitk(data, affine)
345
346
    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Thin wrapper around ``sitk_to_nib`` imported from ``..data.io``.

        The argument is forwarded as it is and the result is returned.
        """
        # The name resolves to the module-level import, not to this method
        return sitk_to_nib(image)
349
350
    def _get_reproducing_arguments(self):
351
        """
352
        Return a dictionary with the arguments that would be necessary to
353
        reproduce the transform exactly.
354
        """
355
        reproducing_arguments = {
356
            'include': self.include,
357
            'exclude': self.exclude,
358
            'copy': self.copy,
359
        }
360
        args_names = {name: getattr(self, name) for name in self.args_names}
361
        reproducing_arguments.update(args_names)
362
        return reproducing_arguments
363
364
    def is_invertible(self):
365
        return hasattr(self, 'invert_transform')
366
367
    def inverse(self):
368
        if not self.is_invertible():
369
            raise RuntimeError(f'{self.name} is not invertible')
370
        new = copy.deepcopy(self)
371
        new.invert_transform = not self.invert_transform
372
        return new
373
374
    @staticmethod
375
    @contextmanager
376
    def _use_seed(seed):
377
        """Perform an operation using a specific seed for the PyTorch RNG"""
378
        torch_rng_state = torch.random.get_rng_state()
379
        torch.manual_seed(seed)
380
        yield
381
        torch.random.set_rng_state(torch_rng_state)
382
383
    @staticmethod
    def get_sitk_interpolator(interpolation: str) -> int:
        """Thin wrapper around ``get_sitk_interpolator`` from ``.interpolation``.

        The argument is forwarded as it is and the result is returned.
        """
        # The name resolves to the module-level import, not to this method
        return get_sitk_interpolator(interpolation)
386
387
    @staticmethod
388
    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
389
        if bounds_parameters is None:
390
            return None
391
        try:
392
            bounds_parameters = tuple(bounds_parameters)
393
        except TypeError:
394
            bounds_parameters = (bounds_parameters,)
395
396
        # Check that numbers are integers
397
        for number in bounds_parameters:
398
            if not isinstance(number, (int, np.integer)) or number < 0:
399
                message = (
400
                    'Bounds values must be integers greater or equal to zero,'
401
                    f' not "{bounds_parameters}" of type {type(number)}'
402
                )
403
                raise ValueError(message)
404
        bounds_parameters = tuple(int(n) for n in bounds_parameters)
405
        bounds_parameters_length = len(bounds_parameters)
406
        if bounds_parameters_length == 6:
407
            return bounds_parameters
408
        if bounds_parameters_length == 1:
409
            return 6 * bounds_parameters
410
        if bounds_parameters_length == 3:
411
            return tuple(np.repeat(bounds_parameters, 2).tolist())
412
        message = (
413
            'Bounds parameter must be an integer or a tuple of'
414
            f' 3 or 6 integers, not {bounds_parameters}'
415
        )
416
        raise ValueError(message)
417
418
    @staticmethod
419
    def ones(tensor: torch.Tensor) -> torch.Tensor:
420
        return torch.ones_like(tensor, dtype=torch.bool)
421
422
    @staticmethod
423
    def mean(tensor: torch.Tensor) -> torch.Tensor:
424
        mask = tensor > tensor.float().mean()
425
        return mask
426
427
    def get_mask_from_masking_method(
428
            self,
429
            masking_method: TypeMaskingMethod,
430
            subject: Subject,
431
            tensor: torch.Tensor,
432
            labels: list = None,
433
            ) -> torch.Tensor:
434
        if masking_method is None:
435
            return self.ones(tensor)
436
        elif callable(masking_method):
437
            return masking_method(tensor)
438
        elif type(masking_method) is str:
439
            in_subject = masking_method in subject
440
            if in_subject and isinstance(subject[masking_method], LabelMap):
441
                if labels is None:
442
                    return subject[masking_method].data.bool()
443
                else:
444
                    mask_data = subject[masking_method].data
445
                    volumes = [mask_data == label for label in labels]
446
                    return torch.stack(volumes).sum(0).bool()
447
            possible_axis = masking_method.capitalize()
448
            if possible_axis in ANATOMICAL_AXES:
449
                return self.get_mask_from_anatomical_label(
450
                    possible_axis, tensor)
451
        elif type(masking_method) in (tuple, list, int):
452
            return self.get_mask_from_bounds(masking_method, tensor)
453
        first_anat_axes = tuple(s[0] for s in ANATOMICAL_AXES)
454
        message = (
455
            'Masking method must be one of:\n'
456
            ' 1) A callable object, such as a function\n'
457
            ' 2) The name of a label map in the subject'
458
            f' ({subject.get_images_names()})\n'
459
            f' 3) An anatomical label {ANATOMICAL_AXES + first_anat_axes}\n'
460
            ' 4) A bounds parameter'
461
            ' (int, tuple of 3 ints, or tuple of 6 ints)\n'
462
            f' The passed value, "{masking_method}",'
463
            f' of type "{type(masking_method)}", is not valid'
464
        )
465
        raise ValueError(message)
466
467
    @staticmethod
    def get_mask_from_anatomical_label(
            anatomical_label: str,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        """Build a mask that is True in one half of a 4D tensor.

        The half is taken along the second dimension for ``'Left'`` and
        ``'Right'``, along the third for ``'Posterior'`` and ``'Anterior'``
        and along the fourth for ``'Inferior'`` and ``'Superior'``. The first
        label of each pair selects the indices below the middle one.

        Args:
            anatomical_label: One of ``ANATOMICAL_AXES``, case-insensitive.
            tensor: 4D tensor that defines the shape of the mask.

        Raises:
            ValueError: if the label is not in ``ANATOMICAL_AXES``.
        """
        # Assume the image is in RAS orientation
        anatomical_label = anatomical_label.capitalize()
        if anatomical_label not in ANATOMICAL_AXES:
            raise ValueError(
                f'Anatomical label must be one of {ANATOMICAL_AXES}'
                f' not {anatomical_label}'
            )
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        _, width, height, depth = tensor.shape
        # Dimension of the tensor and slice along it for each label
        halves = {
            'Right': (1, slice(width // 2, None)),
            'Left': (1, slice(None, width // 2)),
            'Anterior': (2, slice(height // 2, None)),
            'Posterior': (2, slice(None, height // 2)),
            'Superior': (3, slice(depth // 2, None)),
            'Inferior': (3, slice(None, depth // 2)),
        }
        dimension, half = halves[anatomical_label]
        region = [slice(None)] * 4
        region[dimension] = half
        mask[tuple(region)] = True
        return mask
495
496
    def get_mask_from_bounds(
497
            self,
498
            bounds_parameters: TypeBounds,
499
            tensor: torch.Tensor,
500
            ) -> torch.Tensor:
501
        bounds_parameters = self.parse_bounds(bounds_parameters)
502
        low = bounds_parameters[::2]
503
        high = bounds_parameters[1::2]
504
        i0, j0, k0 = low
505
        i1, j1, k1 = np.array(tensor.shape[1:]) - high
506
        mask = torch.zeros_like(tensor, dtype=torch.bool)
507
        mask[:, i0:i1, j0:j1, k0:k1] = True
508
        return mask
509