Completed
Pull Request — master (#353)
by Fernando
118:39 queued 117:31
created

torchio.transforms.augmentation.spatial.random_anisotropy   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 107
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 64
dl 0
loc 107
rs 10
c 0
b 0
f 0
wmc 9

4 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomAnisotropy.apply_transform() 0 26 3
A RandomAnisotropy.get_params() 0 9 1
A RandomAnisotropy.__init__() 0 15 1
A RandomAnisotropy.parse_axes() 0 8 4
1
import warnings
2
from typing import Union, Tuple, Optional, List, Sequence
3
4
import torch
5
6
from ....torchio import TypeRangeFloat
7
from ....data.subject import Subject
8
from ....utils import to_tuple
9
from .. import RandomTransform
10
from ...preprocessing import Resample
11
12
13
class RandomAnisotropy(RandomTransform):
    r"""Downsample an image along an axis and upsample to initial space.

    This transform simulates an image that has been acquired using anisotropic
    spacing and resampled back to its original spacing.

    Similar to the work by Billot et al.: `Partial Volume Segmentation of Brain
    MRI Scans of any Resolution and
    Contrast <https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18>`_.

    Args:
        axes: Axis or tuple of axes along which the image will be downsampled.
            Each time the transform is applied, one of these axes is chosen
            uniformly at random. If the input image is 2D, axis ``2`` is
            ignored for that call and a warning is issued.
        downsampling: Downsampling factor :math:`m \gt 1`. If a tuple
            :math:`(a, b)` is provided then :math:`m \sim \mathcal{U}(a, b)`.
        image_interpolation: Image interpolation used to upsample the image back
            to its initial spacing. Downsampling is performed using nearest
            neighbor interpolation. See :ref:`Interpolation` for supported
            interpolation types.
        scalars_only: Apply only to instances of :class:`torchio.ScalarImage`.
        p: Probability that this transform will be applied.
        keys: See :class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> transform = tio.RandomAnisotropy(axes=1, downsampling=2)
        >>> transform = tio.RandomAnisotropy(
        ...     axes=(0, 1, 2),
        ...     downsampling=(2, 5),
        ... )   # Multiply spacing of one of the 3 axes by a factor randomly chosen in [2, 5]
        >>> colin = tio.datasets.Colin27()
        >>> transformed = transform(colin)
    """

    def __init__(
            self,
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            downsampling: TypeRangeFloat = (1.5, 5),
            image_interpolation: str = 'linear',
            scalars_only: bool = True,
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.axes = self.parse_axes(axes)
        self.downsampling_range = self.parse_range(
            downsampling, 'downsampling', min_constraint=1)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        self.scalars_only = scalars_only

    def get_params(
            self,
            axes: Tuple[int, ...],
            downsampling_range: Tuple[float, float],
            is_2d: bool,
            ) -> Tuple[int, float]:
        """Sample the axis to downsample and the downsampling factor.

        Args:
            axes: Non-empty sequence of candidate axes; one is picked
                uniformly at random.
            downsampling_range: Bounds passed to ``self.sample_uniform`` to
                draw the downsampling factor.
            is_2d: Not used here; the caller removes axis 2 from ``axes``
                beforehand for 2D inputs. Kept for interface compatibility.

        Returns:
            Tuple ``(axis, downsampling)``.
        """
        index = torch.randint(0, len(axes), (1,)).item()
        axis = axes[index]
        downsampling = self.sample_uniform(*downsampling_range).item()
        return axis, downsampling

    @staticmethod
    def parse_axes(axes: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
        """Validate ``axes`` and return it as a tuple.

        Raises:
            ValueError: If any axis is not an ``int`` in ``(0, 1, 2)``.
        """
        axes_tuple = to_tuple(axes)
        for axis in axes_tuple:
            is_int = isinstance(axis, int)
            if not is_int or axis not in (0, 1, 2):
                raise ValueError('All axes must be 0, 1 or 2')
        return axes_tuple

    def apply_transform(self, subject: Subject) -> Subject:
        """Downsample along a random axis, then resample back.

        The subject is first resampled (nearest neighbor) to a spacing where
        one axis is multiplied by the sampled factor. It is then resampled
        onto the grid of the subject's first image using
        ``self.image_interpolation``.

        Raises:
            RuntimeError: If the input is 2D and axis ``2`` was the only
                configured axis, so no axis is left to downsample.
        """
        is_2d = subject.get_first_image().is_2d()
        # Filter into a local variable instead of mutating self.axes: a 2D
        # subject must not permanently remove axis 2 from this transform,
        # otherwise later 3D subjects would never be downsampled along it.
        axes = tuple(self.axes)
        if is_2d and 2 in axes:
            warnings.warn(f'Input image is 2D, but "2" is in axes: {self.axes}')
            axes = tuple(axis for axis in axes if axis != 2)
        if not axes:
            message = (
                'Input image is 2D and the only axis to downsample was 2,'
                f' so no valid axis remains (axes: {self.axes})'
            )
            raise RuntimeError(message)
        axis, downsampling = self.get_params(
            axes,
            self.downsampling_range,
            is_2d,
        )
        target_spacing = list(subject.spacing)
        target_spacing[axis] *= downsampling
        downsample = Resample(
            tuple(target_spacing),
            image_interpolation='nearest',
            scalars_only=self.scalars_only,
        )
        downsampled = downsample(subject)
        # Resample back onto the original grid of the first image
        upsample = Resample(
            subject.get_first_image(),
            image_interpolation=self.image_interpolation,
            scalars_only=self.scalars_only,
        )
        upsampled = upsample(downsampled)
        return upsampled
107