Passed
Pull Request — master (#353)
by Fernando
01:15
created

RandomAnisotropy.parse_axes()   A

Complexity

Conditions 4

Size

Total Lines 8
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 8
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
import warnings
2
from typing import Union, Tuple, Optional, List, Sequence
3
4
import torch
5
6
from ....torchio import TypeRangeFloat
7
from ....data.subject import Subject
8
from ....utils import to_tuple
9
from ... import SpatialTransform
10
from .. import RandomTransform
11
from ...preprocessing import Resample
12
13
14
class RandomAnisotropy(RandomTransform, SpatialTransform):
    r"""Downsample an image along an axis and upsample to initial space.

    This transform simulates an image that has been acquired using anisotropic
    spacing and resampled back to a more standard spacing.

    Args:
        axes: Axis or tuple of axes along which the image will be downsampled.
            Each time the transform is applied, one of these axes is chosen
            at random. If the input is 2D, axis ``2`` is discarded (with a
            warning) for that call only.
        downsampling: Downsampling factor :math:`m \gt 1`. If a tuple
            :math:`(a, b)` is provided then :math:`m \sim \mathcal{U}(a, b)`.
        image_interpolation: Image interpolation used to upsample the image back
            to its initial spacing. Downsampling is performed using nearest
            neighbor interpolation. See :ref:`Interpolation` for supported
            interpolation types.
        p: Probability that this transform will be applied.
        keys: See :class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If any axis is not 0, 1 or 2, or if the input is 2D and
            no axis other than ``2`` was requested.

    Example:
        >>> import torchio as tio
        >>> transform = tio.RandomAnisotropy(axes=1, downsampling=2.)   # Multiply spacing of second axis by 2
        >>> transform = tio.RandomAnisotropy(
        ...     axes=(0, 1, 2),
        ...     downsampling=(2, 5),
        ... )   # Multiply spacing of one of the 3 axes by a factor randomly chosen in [2, 5]
        >>> colin = tio.datasets.Colin27()
        >>> transformed = transform(colin)  # images have now anisotropic spacing
    """

    def __init__(
            self,
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            downsampling: TypeRangeFloat = (1.5, 5),
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.axes = self.parse_axes(axes)
        self.downsampling_range = self.parse_range(
            downsampling, 'downsampling', min_constraint=1)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            axes: Tuple[int, ...],
            downsampling_range: Tuple[float, float],
            is_2d: bool,
            ) -> Tuple[int, float]:
        """Sample the axis to downsample and the downsampling factor.

        Args:
            axes: Non-empty sequence of candidate axes; one is picked
                uniformly at random.
            downsampling_range: Bounds forwarded to ``self.sample_uniform``
                (presumably a uniform draw between them — see that helper).
            is_2d: Not used here; the 2D handling happens in
                ``apply_transform``. Kept for interface compatibility.

        Returns:
            Tuple ``(axis, downsampling)`` with a Python ``int`` axis and a
            Python ``float`` factor.
        """
        # Convert the sampled index tensor to a plain int before indexing
        # the (tuple or list) of axes.
        index = int(torch.randint(0, len(axes), (1,)).item())
        axis = axes[index]
        downsampling = self.sample_uniform(*downsampling_range).item()
        return axis, downsampling

    @staticmethod
    def parse_axes(axes: Union[int, Tuple[int, ...]]):
        """Validate ``axes`` and return it as a tuple.

        Raises:
            ValueError: If any element is not an ``int`` in ``(0, 1, 2)``.
        """
        axes_tuple = to_tuple(axes)
        for axis in axes_tuple:
            is_int = isinstance(axis, int)
            if not is_int or axis not in (0, 1, 2):
                raise ValueError('All axes must be 0, 1 or 2')
        return axes_tuple

    def apply_transform(self, subject: Subject) -> Subject:
        """Downsample ``subject`` along one random axis, then upsample back.

        Downsampling uses nearest-neighbor resampling to a spacing that is
        multiplied by the sampled factor along the chosen axis. Upsampling
        resamples back onto the space of the subject's first image using
        ``self.image_interpolation``.
        """
        is_2d = subject.get_first_image().is_2d()
        # Filter into a local tuple instead of mutating self.axes, so that a
        # 2D subject does not change how later (possibly 3D) subjects are
        # transformed by this same instance.
        axes = tuple(self.axes)
        if is_2d and 2 in axes:
            warnings.warn(f'Input image is 2D, but "2" is in axes: {self.axes}')
            # Drop every occurrence of 2, not just the first one.
            axes = tuple(axis for axis in axes if axis != 2)
            if not axes:
                raise ValueError(
                    'Input image is 2D and no axis other than 2 was given:'
                    f' {self.axes}'
                )
        axis, downsampling = self.get_params(
            axes,
            self.downsampling_range,
            is_2d,
        )
        target_spacing = list(subject.spacing)
        target_spacing[axis] *= downsampling
        downsample = Resample(
            tuple(target_spacing),
            image_interpolation='nearest',
        )
        downsampled = downsample(subject)
        upsample = Resample(
            subject.get_first_image(),
            image_interpolation=self.image_interpolation,
        )
        upsampled = upsample(downsampled)
        return upsampled