Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

torchio.transforms.augmentation.spatial.random_flip   A

Complexity

Total Complexity 13

Size/Duplication

Total Lines 82
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 58
dl 0
loc 82
rs 10
c 0
b 0
f 0
wmc 13

4 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomFlip.parse_axes() 0 8 4
A RandomFlip.__init__() 0 11 1
B RandomFlip.apply_transform() 0 29 6
A RandomFlip.get_params() 0 8 2
1
from typing import Union, Tuple, Optional, List
2
import torch
3
from ....torchio import DATA
4
from ....data.subject import Subject
5
from ....utils import to_tuple
6
from .. import RandomTransform
7
8
9
class RandomFlip(RandomTransform):
    """Reverse the order of elements in an image along the given axes.

    Args:
        axes: Axis or tuple of axes along which the image will be flipped.
        flip_probability: Probability that the image will be flipped. This is
            computed on a per-axis basis.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: If the input image is 2D, all axes should be in ``(0, 1)``.
        Otherwise a :py:class:`RuntimeError` is raised whenever the
        transform is applied to that image.
    """

    def __init__(
            self,
            axes: Union[int, Tuple[int, ...]] = 0,
            flip_probability: float = 0.5,
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        # Validate up front so invalid axes fail at construction time.
        self.axes = self.parse_axes(axes)
        self.flip_probability = self.parse_probability(
            flip_probability,
        )

    def apply_transform(self, sample: Subject) -> Subject:
        """Flip every image in ``sample`` along randomly chosen axes.

        One random draw of axes (see :meth:`get_params`) is shared by all
        images in the sample and recorded with ``sample.add_transform``.

        Args:
            sample: Subject whose image dicts are updated in place.

        Returns:
            The same ``sample`` object, with each image's ``DATA`` entry
            replaced by its flipped version (left untouched when no axis was
            drawn for flipping).

        Raises:
            RuntimeError: If an image looks 2D (``data.shape[-3] == 1``)
                while the configured axes include axis 2.
        """
        axes_to_flip_hot = self.get_params(self.axes, self.flip_probability)
        random_parameters_dict = {'axes': axes_to_flip_hot}
        uses_third_axis = any(axis > 1 for axis in self.axes)
        items = sample.get_images_dict(intensity_only=False).items()
        for image_name, image_dict in items:
            data = image_dict[DATA]
            is_2d = data.shape[-3] == 1
            # Validate against the configured axes, not only the ones drawn
            # for flipping, so a misconfiguration fails on every call instead
            # of only when axis 2 happens to be selected at random.
            if is_2d and uses_third_axis:
                message = (
                    f'Image "{image_name}" with shape {data.shape} seems to'
                    ' be 2D, so all axes must be in (0, 1),'
                    f' but they are {self.axes}'
                )
                raise RuntimeError(message)
            # Images are 4D, so spatial axis i is tensor dim i + 1 (dim 0 is
            # presumably channels). If the user is using 2D images and they
            # use (0, 1) for axes, they probably mean (1, 2), so shift by one
            # more. This makes the transform more user-friendly.
            offset = 2 if is_2d else 1
            dims = [
                axis + offset
                for axis, flip_this in enumerate(axes_to_flip_hot)
                if flip_this
            ]
            if dims:  # nothing drawn: keep the tensor and skip the copy
                image_dict[DATA] = torch.flip(data, dims=dims)
        sample.add_transform(self, random_parameters_dict)
        return sample

    @staticmethod
    def get_params(axes: Tuple[int, ...], probability: float) -> List[bool]:
        """Decide independently, for each axis in ``axes``, whether to flip it.

        Args:
            axes: Axes in ``(0, 1, 2)`` that may be flipped.
            probability: Per-axis probability of flipping.

        Returns:
            A list of three booleans. Entry ``i`` is ``True`` if spatial axis
            ``i`` should be flipped. Axes not listed in ``axes`` stay
            ``False``.
        """
        axes_hot = [False, False, False]
        for axis in axes:
            # torch.rand samples from [0, 1), so probability 1 always flips
            # and probability 0 never does.
            random_number = torch.rand(1)
            axes_hot[axis] = bool(probability > random_number)
        return axes_hot

    @staticmethod
    def parse_axes(axes: Union[int, Tuple[int, ...]]):
        """Convert ``axes`` with ``to_tuple`` and validate every entry.

        Args:
            axes: A single axis or a tuple of axes.

        Returns:
            The result of ``to_tuple(axes)``.

        Raises:
            ValueError: If any axis is not an ``int`` in ``(0, 1, 2)``.
        """
        axes_tuple = to_tuple(axes)
        for axis in axes_tuple:
            is_int = isinstance(axis, int)
            if not is_int or axis not in (0, 1, 2):
                raise ValueError('All axes must be 0, 1 or 2')
        return axes_tuple
82