Passed
Push — master ( c291a8...879ee9 )
by Fernando
59s
created

torchio.transforms.augmentation.spatial.random_downsample   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 78
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 47
dl 0
loc 78
rs 10
c 0
b 0
f 0
wmc 7

4 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomDownsample.parse_axes() 0 8 4
A RandomDownsample.apply_transform() 0 14 1
A RandomDownsample.__init__() 0 11 1
A RandomDownsample.get_params() 0 8 1
1
from typing import Union, Tuple, Optional, List
2
import torch
3
from ....torchio import TypeRangeFloat
4
from ....data.subject import Subject
5
from ....utils import to_tuple
6
from .. import RandomTransform
7
from ...preprocessing import Resample
8
9
10
class RandomDownsample(RandomTransform):
    r"""Downsample an image along an axis.

    This transform simulates an image that has been acquired using anisotropic
    spacing, using downsampling with nearest neighbor interpolation.

    Args:
        axes: Axis or tuple of axes along which the image will be downsampled.
            Each time the transform is applied, one of them is chosen
            uniformly at random. Every axis must be ``0``, ``1`` or ``2``.
        downsampling: Downsampling factor :math:`m \gt 1`. If a tuple
            :math:`(a, b)` is provided then :math:`m \sim \mathcal{U}(a, b)`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    Example:
        >>> from torchio import RandomDownsample
        >>> from torchio.datasets import Colin27
        >>> # Multiply spacing of second axis by 2
        >>> transform = RandomDownsample(axes=1, downsampling=2.)
        >>> # Multiply spacing of one of the 3 axes by a factor randomly
        >>> # chosen in [2, 5]
        >>> transform = RandomDownsample(
        ...     axes=(0, 1, 2), downsampling=(2, 5)
        ... )
        >>> colin = Colin27()
        >>> transformed = transform(colin)  # images have now anisotropic spacing
    """

    def __init__(
            self,
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            downsampling: TypeRangeFloat = (1.5, 5),
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        # Validated tuple of candidate axes to downsample.
        self.axes = self.parse_axes(axes)
        # Range for the downsampling factor. parse_range is asked to enforce
        # a lower bound of 1: since the factor multiplies the spacing in
        # apply_transform, a factor below 1 would upsample instead.
        self.downsampling_range = self.parse_range(
            downsampling, 'downsampling', min_constraint=1)

    @staticmethod
    def get_params(
            axes: Tuple[int, ...],
            downsampling_range: Tuple[float, float],
            ) -> Tuple[int, float]:
        """Sample the axis to downsample and the downsampling factor.

        Args:
            axes: Candidate axes; one is picked uniformly at random.
            downsampling_range: ``(low, high)`` bounds for the factor.

        Returns:
            Tuple ``(axis, downsampling)``: the chosen element of ``axes``
            and a factor drawn uniformly from ``downsampling_range``.
        """
        # Convert the 1-element index tensor to a plain int before indexing
        # the tuple instead of relying on Tensor.__index__.
        index = int(torch.randint(0, len(axes), (1,)).item())
        axis = axes[index]
        downsampling = torch.FloatTensor(1).uniform_(*downsampling_range).item()
        return axis, downsampling

    @staticmethod
    def parse_axes(axes: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
        """Normalize ``axes`` with ``to_tuple`` and validate every entry.

        Raises:
            ValueError: If any axis is not an integer in ``(0, 1, 2)``.
                Booleans are rejected explicitly: ``bool`` subclasses
                ``int`` and ``True == 1``, so ``True`` would otherwise be
                silently accepted as axis 1.
        """
        axes_tuple = to_tuple(axes)
        for axis in axes_tuple:
            is_int = isinstance(axis, int) and not isinstance(axis, bool)
            if not is_int or axis not in (0, 1, 2):
                raise ValueError('All axes must be 0, 1 or 2')
        return axes_tuple

    def apply_transform(self, sample: Subject) -> Subject:
        """Resample ``sample`` with one axis' spacing scaled by a random factor.

        The new spacing is applied with ``Resample`` using nearest-neighbor
        interpolation for images. The sampled parameters are passed to
        ``sample.add_transform``.
        """
        axis, downsampling = self.get_params(self.axes, self.downsampling_range)
        random_parameters_dict = {'axis': axis, 'downsampling': downsampling}

        # Only the chosen axis gets a coarser spacing; the others are kept.
        target_spacing = list(sample.spacing)
        target_spacing[axis] *= downsampling
        transform = Resample(
            tuple(target_spacing),
            image_interpolation='nearest',
            # NOTE(review): the original comment said the subject was
            # "already copied in super().__init__", which this class does
            # not show; presumably the copy happens in the parent
            # transform's call path. Confirm before relying on copy=False.
            copy=False,
        )
        sample = transform(sample)
        # Presumably records the sampled parameters in the subject's
        # transform history; verify against Subject.add_transform.
        sample.add_transform(self, random_parameters_dict)
        return sample
78