Passed
Pull Request — master (#353)
by Fernando
01:16
created

RandomAnisotropy.apply_transform()   A

Complexity

Conditions 1

Size

Total Lines 15
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 13
nop 2
dl 0
loc 15
rs 9.75
c 0
b 0
f 0
1
from typing import Union, Tuple, Optional, List, Sequence
2
import torch
3
from ....torchio import TypeRangeFloat
4
from ....data.subject import Subject
5
from ....utils import to_tuple
6
from ... import SpatialTransform
7
from .. import RandomTransform
8
from ...preprocessing import Resample
9
10
11
class RandomAnisotropy(RandomTransform, SpatialTransform):
    r"""Downsample an image along an axis and upsample to initial space.

    This transform simulates an image that has been acquired using anisotropic
    spacing and resampled back to a more standard spacing.

    Args:
        axes: Axis or tuple of axes along which the image will be downsampled.
            Every axis must be 0, 1 or 2. A single axis is picked at random
            from this collection each time the transform is applied.
        downsampling: Downsampling factor :math:`m \gt 1`. If a tuple
            :math:`(a, b)` is provided then :math:`m \sim \mathcal{U}(a, b)`.
        image_interpolation: Image interpolation used to upsample the image back
            to its initial spacing. Downsampling is performed using nearest
            neighbor interpolation. See :ref:`Interpolation` for supported
            interpolation types.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> transform = tio.RandomAnisotropy(axes=1, downsampling=2.)   # Multiply spacing of second axis by 2
        >>> transform = tio.RandomAnisotropy(
        ...     axes=(0, 1, 2),
        ...     downsampling=(2, 5),
        ... )   # Multiply spacing of one of the 3 axes by a factor randomly chosen in [2, 5]
        >>> colin = tio.datasets.Colin27()
        >>> transformed = transform(colin)  # images have now anisotropic spacing
    """

    def __init__(
            self,
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            downsampling: TypeRangeFloat = (1.5, 5),
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        # Validate eagerly so that bad arguments fail at construction time
        # instead of when the transform is first applied.
        self.axes = self.parse_axes(axes)
        self.downsampling_range = self.parse_range(
            downsampling, 'downsampling', min_constraint=1)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            axes: Tuple[int, ...],
            downsampling_range: Tuple[float, float],
            ) -> Tuple[int, float]:
        """Randomly pick the axis to degrade and the downsampling factor.

        Args:
            axes: Candidate axes. One of them is selected with an index
                drawn by :func:`torch.randint`.
            downsampling_range: Bounds unpacked into ``self.sample_uniform``
                to draw the downsampling factor.

        Returns:
            Tuple ``(axis, downsampling)`` with the selected axis and the
            sampled downsampling factor.
        """
        # Convert the one-element index tensor to a plain int explicitly
        # instead of relying on implicit tensor-to-index conversion when
        # subscripting a Python tuple.
        index = int(torch.randint(0, len(axes), (1,)).item())
        axis = axes[index]
        downsampling = self.sample_uniform(*downsampling_range).item()
        return axis, downsampling

    @staticmethod
    def parse_axes(axes: Union[int, Tuple[int, ...]]):
        """Validate ``axes`` and return them as a tuple.

        Args:
            axes: Single axis or tuple of axes; converted with ``to_tuple``.

        Returns:
            The result of ``to_tuple(axes)`` once every element is validated.

        Raises:
            ValueError: If any element is not an ``int`` or is not one of
                0, 1 or 2.
        """
        axes_tuple = to_tuple(axes)
        for axis in axes_tuple:
            is_int = isinstance(axis, int)
            if not is_int or axis not in (0, 1, 2):
                raise ValueError('All axes must be 0, 1 or 2')
        return axes_tuple

    def apply_transform(self, subject: Subject) -> Subject:
        """Downsample the subject along a random axis, then resample it back.

        Args:
            subject: Subject to transform.

        Returns:
            The result of resampling the downsampled subject, using the first
            image of the input subject as the resampling target.
        """
        axis, downsampling = self.get_params(self.axes, self.downsampling_range)
        # Only the selected axis gets a coarser spacing; the other two keep
        # their original values.
        target_spacing = list(subject.spacing)
        target_spacing[axis] *= downsampling
        downsample = Resample(
            tuple(target_spacing),
            image_interpolation='nearest',
        )
        downsampled = downsample(subject)
        # The first image of the input subject is the reference for the
        # second resampling, performed with the user-selected interpolation.
        upsample = Resample(
            subject.get_first_image(),
            image_interpolation=self.image_interpolation,
        )
        upsampled = upsample(downsampled)
        return upsampled
86