Passed
Pull Request — master (#625)
by Fernando
01:52
created

torchio.transforms.augmentation.spatial.random_flip._parse_axes()   A

Complexity

Conditions 4

Size

Total Lines 13
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 11
nop 1
dl 0
loc 13
rs 9.85
c 0
b 0
f 0
1
from torchio.data.image import Image
2
from typing import Union, Sequence, List
3
import torch
4
import numpy as np
5
from ....data.subject import Subject
6
from ... import SpatialTransform
7
from .. import RandomTransform
8
9
10
class RandomFlip(RandomTransform, SpatialTransform):
    """Reverse the order of elements in an image along the given axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image might be flipped. If they are integers, they must be in
            ``(0, 1, 2)``. Anatomical labels may also be used, such as
            ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
            ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``,
            ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or
            ``'i'`` (inferior). Only the first letter of the string will be
            used. If the image is 2D, ``'Height'`` and ``'Width'`` may be
            used.
        flip_probability: Probability that the image will be flipped. This is
            computed on a per-axis basis.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> fpg = tio.datasets.FPG()
        >>> flip = tio.RandomFlip(axes=('LR',))  # flip along lateral axis only
        >>> flipped = flip(fpg)

    .. tip:: It is handy to specify the axes as anatomical labels when the
        image orientation is not known.
    """

    def __init__(
            self,
            axes: Union[int, Sequence[int], str, Sequence[str]] = 0,
            flip_probability: float = 0.5,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.axes = self.parse_axes(axes)
        self.flip_probability = self.parse_probability(flip_probability)

    def apply_transform(self, subject: Subject) -> Subject:
        # Turn self.axes (which may hold anatomical labels) into integer
        # spatial indices for this subject.
        potential_axes = self.ensure_axes_indices(subject, self.axes)
        # One independent draw per spatial axis (see get_params).
        flips = self.get_params(self.flip_probability)
        # Keep the axes that were both requested and drawn, in ascending
        # order, as plain Python ints.
        axes = [
            axis
            for axis, flip in enumerate(flips)
            if flip and axis in potential_axes
        ]
        if not axes:
            # Nothing to flip: return the subject untouched.
            return subject
        arguments = {'axes': axes}
        transform = Flip(**self.add_include_exclude(arguments))
        return transform(subject)

    @staticmethod
    def get_params(probability: float) -> List[bool]:
        """Draw one flip decision per spatial axis.

        Args:
            probability: Probability of flipping each axis.

        Returns:
            A list of three booleans; element ``i`` is ``True`` when
            ``probability`` exceeds a uniform sample in ``[0, 1)``, i.e. when
            axis ``i`` is selected for flipping.
        """
        return (probability > torch.rand(3)).tolist()
66
67
68
class Flip(SpatialTransform):
    """Reverse the order of elements in an image along the given axes.

    This transform is deterministic: every requested axis is flipped for
    every selected image. For the randomized variant, see
    :class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image will be flipped. See
            :class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`
            for more information.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is handy to specify the axes as anatomical labels when the
        image orientation is not known.
    """

    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        parsed_axes = self.parse_axes(axes)
        self.axes = parsed_axes
        # NOTE(review): presumably consumed by the base class to describe or
        # reproduce this transform's arguments — confirm.
        self.args_names = ('axes',)

    def apply_transform(self, subject: Subject) -> Subject:
        # Integer axis indices for this subject; labels such as 'LR' in
        # self.axes are resolved here.
        indices = self.ensure_axes_indices(subject, self.axes)
        images = self.get_images(subject)
        for img in images:
            # Modifies each image's data in place.
            _flip_image(img, indices)
        return subject

    @staticmethod
    def is_invertible():
        """Report that this transform can be undone."""
        return True

    def inverse(self):
        """Return the inverse transform.

        Flipping twice along the same axes restores the original data, so
        the transform is its own inverse.
        """
        return self
100
101
102
def _flip_image(image: 'Image', axes: Sequence[int]) -> 'Image':
    """Flip the data of ``image`` in place along the given spatial axes.

    Args:
        image: Image whose data is replaced (via ``set_data``) by a flipped
            copy of ``image.numpy()``.
        axes: Spatial axis indices, e.g. values in ``(0, 1, 2)``. Each one is
            shifted by one because dimension 0 of the data array is not a
            spatial dimension (presumably the channel dimension).

    Returns:
        The same ``image`` instance, so the declared return type holds.
    """
    # Explicit tuple of Python ints: the form np.flip documents for
    # flipping along several axes at once. An empty tuple flips nothing.
    spatial_axes = tuple(int(axis) + 1 for axis in axes)
    flipped = np.flip(image.numpy(), axis=spatial_axes)
    # np.flip returns a view with negative strides, which torch cannot wrap,
    # so materialise a contiguous copy before converting.
    data = torch.as_tensor(flipped.copy())
    image.set_data(data)
    return image
109