Passed
Pull Request — master (#353)
by Fernando
01:09
created

Flip.inverse()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
from typing import Union, Tuple, Optional, List
2
import torch
3
import numpy as np
4
from ....torchio import DATA
5
from ....data.subject import Subject
6
from ....utils import to_tuple
7
from ... import SpatialTransform
8
from .. import RandomTransform
9
10
11
class RandomFlip(RandomTransform, SpatialTransform):
    """Reverse the order of elements in an image along the given axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image might be flipped. If they are integers, they must be in
            ``(0, 1, 2)``. Anatomical labels may also be used, such as
            ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
            ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``,
            ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or
            ``'i'`` (inferior). Only the first letter of the string will be
            used. If the image is 2D, ``'Height'`` and ``'Width'`` may be
            used.
        flip_probability: Probability that the image will be flipped. This is
            computed on a per-axis basis.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> fpg = tio.datasets.FPG()
        >>> flip = tio.RandomFlip(axes=('LR',))  # flip along lateral axis only

    .. tip:: It is handy to specify the axes as anatomical labels when the image
        orientation is not known.
    """

    def __init__(
            self,
            # Strings are accepted as well as integers (see the example in
            # the class docstring), so the annotation must allow both.
            axes: Union[int, str, Tuple[Union[int, str], ...]] = 0,
            flip_probability: float = 0.5,
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.axes = _parse_axes(axes)
        self.flip_probability = self.parse_probability(flip_probability)

    def apply_transform(self, subject: Subject) -> Subject:
        """Flip the subject along a random subset of the configured axes.

        Args:
            subject: Subject whose images may be flipped.

        Returns:
            The result of calling a :py:class:`Flip` transform, built with
            the sampled axes, on ``subject``.
        """
        # Resolve anatomical labels (if any) to indices before sampling.
        candidate_axes = _ensure_axes_indices(subject, self.axes)
        sampled = self.get_params(self.flip_probability)
        # An axis is flipped only if it was sampled AND the user allowed it.
        axes_to_flip = [
            axis
            for axis, flip in enumerate(sampled)
            if flip and axis in candidate_axes
        ]
        # NOTE(review): ``keys`` is not forwarded to ``Flip`` here -- confirm
        # whether that is intended.
        return Flip(axes=axes_to_flip)(subject)

    @staticmethod
    def get_params(probability: float) -> List[bool]:
        """Sample one independent flip decision for each of three axes.

        Args:
            probability: Threshold compared against three values drawn with
                :py:func:`torch.rand`.

        Returns:
            List of three booleans; an entry is ``True`` when ``probability``
            is greater than the corresponding random draw.
        """
        return (probability > torch.rand(3)).tolist()
63
64
65
class Flip(SpatialTransform):
    """Flip every image of a subject along the requested spatial axes.

    This is the deterministic counterpart of
    :py:class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`:
    all the given axes are always reversed.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image will be flipped. Strings are accepted too; see
            :py:class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`
            for more information.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. tip:: It is handy to specify the axes as anatomical labels when the image
        orientation is not known.
    """

    def __init__(self, axes, keys: Optional[List[str]] = None):
        super().__init__(keys=keys)
        # Validation happens once, at construction time.
        self.axes = _parse_axes(axes)

    def apply_transform(self, subject: Subject) -> Subject:
        """Flip each image returned by ``self.get_images`` and return the subject."""
        # String axes are resolved against the subject; integer axes are
        # used as they are.
        flip_indices = _ensure_axes_indices(subject, self.axes)
        for current_image in self.get_images(subject):
            _flip_image(current_image, flip_indices)
        return subject

    @staticmethod
    def is_invertible():
        """Report that this transform can be inverted."""
        return True

    def inverse(self):
        """Return the inverse transform, which is this same instance.

        Flipping twice along the same axes restores the original order.
        """
        return self
95
96
97
def _parse_axes(axes: Union[int, Tuple[int, ...]]):
    """Validate the axes argument and return it as a tuple.

    Args:
        axes: Single axis or sequence of axes. Each entry must be a string
            (passed through unchecked) or one of the integers 0, 1 or 2.

    Returns:
        The result of ``to_tuple(axes)``.

    Raises:
        ValueError: If an entry is neither a string nor an integer in
            ``(0, 1, 2)``.
    """
    parsed = to_tuple(axes)
    for candidate in parsed:
        # Strings are resolved to indices later, once an image is available.
        if isinstance(candidate, str):
            continue
        if isinstance(candidate, int) and candidate in (0, 1, 2):
            continue
        message = (
            f'All axes must be 0, 1 or 2, but found "{candidate}"'
            f' with type {type(candidate)}'
        )
        raise ValueError(message)
    return parsed
110
111
112
def _ensure_axes_indices(subject, axes):
113
    if any(isinstance(n, str) for n in axes):
114
        subject.check_consistent_orientation()
115
        image = subject.get_first_image()
116
        axes = sorted(3 + image.axis_name_to_index(n) for n in axes)
117
    return axes
118
119
120
def _flip_image(image, axes):
    """Reverse the data of ``image`` along ``axes`` and store the result.

    Args:
        image: Object providing ``numpy()`` and item assignment under the
            ``DATA`` key.
        axes: Axis indices; each one is shifted by one before flipping
            (presumably to skip a leading channels dimension -- confirm).
    """
    shifted_axes = 1 + np.array(axes)
    flipped = np.flip(image.numpy(), axis=shifted_axes)
    # ``np.flip`` returns a view with negative strides; copying removes them
    # before the array is handed to ``torch.from_numpy``.
    image[DATA] = torch.from_numpy(flipped.copy())
127