Passed
Push — master ( 066103...5d2729 )
by Fernando
01:22
created

RandomDownsample.get_params()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 2
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
from typing import Union, Tuple, Optional, List
2
import torch
3
from ....torchio import DATA
4
from ....data.subject import Subject
5
from ....utils import to_tuple
6
from .. import RandomTransform
7
from ...preprocessing import Resample
8
9
10
class RandomDownsample(RandomTransform):
    r"""Downsample an image along an axis.

    This transform simulates an image that has been acquired using anisotropic
    spacing, using downsampling with nearest neighbor interpolation.

    Args:
        axes: Axis or tuple of axes along which the image will be downsampled.
            One of them is drawn at random each time the transform is
            applied. Every axis must be ``0``, ``1`` or ``2``.
        downsampling: Downsampling factor :math:`m \gt 1`. If a tuple
            :math:`(a, b)` is provided then :math:`m \sim \mathcal{U}(a, b)`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
    """

    def __init__(
            self,
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            downsampling: Union[float, Tuple[float, float]] = (1.5, 5),
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        self.axes = self.parse_axes(axes)
        self.downsampling_range = self.parse_downsampling(downsampling)

    @staticmethod
    def get_params(
            axes: Tuple[int, ...],
            downsampling_range: Tuple[float, float],
            ) -> Tuple[int, float]:
        """Sample the axis and the downsampling factor for one application.

        Args:
            axes: Candidate axes; one of them is drawn uniformly at random.
            downsampling_range: ``(low, high)`` bounds passed to
                :meth:`torch.Tensor.uniform_`.

        Returns:
            A tuple ``(axis, downsampling)`` where ``axis`` is an element of
            ``axes`` and ``downsampling`` is a Python float.
        """
        # .item() converts the 1-element tensor to a plain int so ``axes`` is
        # indexed with an int rather than with a tensor.
        index = torch.randint(0, len(axes), (1,)).item()
        axis = axes[index]
        downsampling = torch.FloatTensor(1).uniform_(*downsampling_range).item()
        return axis, downsampling

    @staticmethod
    def parse_downsampling(downsampling_factor):
        """Validate the downsampling factor and return it as a ``(low, high)`` pair.

        A scalar ``m`` becomes ``(m, m)``, so the sampled factor is always ``m``.

        Raises:
            ValueError: If a sequence does not have exactly two elements, if
                any factor is not greater than 1, or if ``low > high``.
        """
        try:
            # Materialize once so one-shot iterables are not left exhausted.
            factors = tuple(downsampling_factor)
        except TypeError:
            factors = downsampling_factor, downsampling_factor
        if len(factors) != 2:
            message = (
                'Downsampling factor must be a number or a sequence of two'
                f' numbers, not {downsampling_factor}'
            )
            raise ValueError(message)
        for n in factors:
            if n <= 1:
                message = (
                    f'Downsampling factor must be a number > 1, not {n}')
                raise ValueError(message)
        low, high = factors
        if low > high:
            message = (
                'Downsampling range lower bound must not exceed the upper'
                f' bound, got {factors}'
            )
            raise ValueError(message)
        return factors

    @staticmethod
    def parse_axes(axes: Union[int, Tuple[int, ...]]):
        """Validate ``axes`` and return the result of ``to_tuple(axes)``.

        Raises:
            ValueError: If no axis is given, or if any axis is not one of the
                ints ``0``, ``1`` or ``2``.
        """
        axes_tuple = to_tuple(axes)
        if len(axes_tuple) == 0:
            # get_params would otherwise call torch.randint(0, 0, ...) and fail
            raise ValueError('At least one axis must be given')
        for axis in axes_tuple:
            # bool is a subclass of int and True == 1, so reject it explicitly
            is_int = isinstance(axis, int) and not isinstance(axis, bool)
            if not is_int or axis not in (0, 1, 2):
                raise ValueError('All axes must be 0, 1 or 2')
        return axes_tuple

    def apply_transform(self, sample: Subject) -> Subject:
        """Resample ``sample`` to a coarser spacing along one random axis.

        The spacing along the sampled axis is multiplied by the sampled
        downsampling factor; the other axes keep ``sample.spacing``. Images
        are resampled with ``image_interpolation='nearest'``. The sampled
        ``axis`` and ``downsampling`` are recorded via ``add_transform``.
        """
        axis, downsampling = self.get_params(self.axes, self.downsampling_range)
        random_parameters_dict = {'axis': axis, 'downsampling': downsampling}

        target_spacing = list(sample.spacing)
        target_spacing[axis] *= downsampling
        transform = Resample(
            tuple(target_spacing),
            image_interpolation='nearest',
            # NOTE(review): assumes the subject was already copied upstream
            # (original comment said "in super().__init__") — confirm.
            copy=False,
        )
        sample = transform(sample)
        sample.add_transform(self, random_parameters_dict)
        return sample
81