Passed
Pull Request — master (#225)
by Fernando
01:58
created

RandomDownsample.__init__()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 9
nop 5
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
from typing import Union, Tuple, Optional, List
2
import torch
3
from ....torchio import DATA
4
from ....data.subject import Subject
5
from ....utils import to_tuple
6
from .. import RandomTransform
7
from ...preprocessing import Resample
8
9
10
class RandomDownsample(RandomTransform):
    """Downsample an image along an axis.

    This transform simulates an image that has been acquired using anisotropic
    spacing, using downsampling with nearest neighbor interpolation.

    .. image:: https://user-images.githubusercontent.com/12688084/87075276-fe9d9d00-c217-11ea-81a4-db0cac163ce7.png
        :alt: Simulation of an image with highly anisotropic spacing

    Args:
        axes: Axis or tuple of axes along which the image will be downsampled.
            Every axis must be 0, 1 or 2. One of them is picked at random
            each time the transform is applied.
        downsampling: Downsampling factor :math:`f > 1`. If a tuple
            :math:`(a, b)` is provided, :math:`f` is sampled uniformly
            between :math:`a` and :math:`b` each time the transform is
            applied.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
    """

    def __init__(
            self,
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            downsampling: Union[float, Tuple[float, float]] = (1.5, 5),
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        # Validate eagerly so that bad arguments fail at construction time
        self.axes = self.parse_axes(axes)
        self.downsampling_range = self.parse_downsampling(downsampling)

    @staticmethod
    def get_params(
            axes: Tuple[int, ...],
            downsampling_range: Tuple[float, float],
            ) -> Tuple[int, float]:
        """Sample the axis and the downsampling factor to be used.

        Args:
            axes: Candidate axes; one of them is chosen at random.
            downsampling_range: Bounds of the uniform distribution from which
                the downsampling factor is drawn.

        Returns:
            Tuple ``(axis, downsampling)`` with the chosen axis and the
            sampled factor as a Python ``float``.
        """
        # Convert explicitly to a Python int instead of indexing the tuple
        # with a one-element tensor
        axis_index = torch.randint(0, len(axes), (1,)).item()
        axis = axes[axis_index]
        downsampling = torch.FloatTensor(1).uniform_(*downsampling_range).item()
        return axis, downsampling

    @staticmethod
    def parse_downsampling(downsampling_factor):
        """Validate the downsampling argument and express it as a range.

        Args:
            downsampling_factor: A single number, or an iterable with the two
                bounds ``(min, max)`` of the sampling range.

        Returns:
            Tuple ``(min, max)``. A single number ``f`` becomes ``(f, f)``.

        Raises:
            ValueError: If the range does not contain exactly two values, if
                any value is not greater than 1, or if ``min > max``.
        """
        try:
            # Materialize so that one-shot iterables are not exhausted by
            # the validation below
            factors = tuple(downsampling_factor)
        except TypeError:
            # Not iterable: a single number was given
            factors = downsampling_factor, downsampling_factor
        if len(factors) != 2:
            message = (
                'Downsampling factor must be a number or a sequence of'
                f' two numbers, not {downsampling_factor}')
            raise ValueError(message)
        for n in factors:
            if n <= 1:
                message = (
                    f'Downsampling factor must be a number > 1, not {n}')
                raise ValueError(message)
        minimum, maximum = factors
        if minimum > maximum:
            # Fail here rather than later inside Tensor.uniform_
            message = (
                'Downsampling range must be sorted as (min, max),'
                f' not {factors}')
            raise ValueError(message)
        return factors

    @staticmethod
    def parse_axes(axes: Union[int, Tuple[int, ...]]):
        """Validate the axes argument and return it as a tuple.

        Args:
            axes: Axis or tuple of axes.

        Returns:
            The result of ``to_tuple(axes)`` once every element has been
            checked.

        Raises:
            ValueError: If any axis is not an ``int`` in ``(0, 1, 2)``.
        """
        axes_tuple = to_tuple(axes)
        for axis in axes_tuple:
            is_int = isinstance(axis, int)
            if not is_int or axis not in (0, 1, 2):
                raise ValueError('All axes must be 0, 1 or 2')
        return axes_tuple

    def apply_transform(self, sample: Subject) -> Subject:
        """Resample the subject with a coarser spacing along a random axis.

        The spacing of the subject along the sampled axis is multiplied by
        the sampled downsampling factor and the subject is resampled to that
        spacing using nearest neighbor interpolation.

        Args:
            sample: Subject to be transformed.

        Returns:
            The subject returned by the resampling transform, on which this
            transform and its random parameters have been registered through
            ``add_transform``.
        """
        axis, downsampling = self.get_params(self.axes, self.downsampling_range)
        random_parameters_dict = {'axis': axis, 'downsampling': downsampling}

        # Only the spacing along the chosen axis is made coarser
        target_spacing = list(sample.spacing)
        target_spacing[axis] *= downsampling
        transform = Resample(
            tuple(target_spacing),
            image_interpolation='nearest',
        )
        sample = transform(sample)
        sample.add_transform(self, random_parameters_dict)
        return sample
82