Completed
Pull Request — master (#353)
by Fernando
118:39 queued 117:31
created

torchio.transforms.augmentation.spatial.random_flip   A

Complexity

Total Complexity 17

Size/Duplication

Total Lines 128
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 71
dl 0
loc 128
rs 10
c 0
b 0
f 0
wmc 17

7 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomFlip.__init__() 0 10 1
A RandomFlip.apply_transform() 0 10 3
A RandomFlip.get_params() 0 3 1
A Flip.apply_transform() 0 5 2
A Flip.__init__() 0 4 1
A Flip.inverse() 0 2 1
A Flip.is_invertible() 0 3 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _ensure_axes_indices() 0 6 2
A _parse_axes() 0 13 4
A _flip_image() 0 7 1
1
from typing import Union, Tuple, Optional, List, Sequence
2
import torch
3
import numpy as np
4
from ....torchio import DATA
5
from ....data.subject import Subject
6
from ....utils import to_tuple
7
from ... import SpatialTransform
8
from .. import RandomTransform
9
10
11
class RandomFlip(RandomTransform, SpatialTransform):
    """Reverse the order of elements in an image along the given axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image might be flipped. If they are integers, they must be in
            ``(0, 1, 2)``. Anatomical labels may also be used, such as
            ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
            ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``,
            ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or
            ``'i'`` (inferior). Only the first letter of the string will be
            used. If the image is 2D, ``'Height'`` and ``'Width'`` may be
            used.
        flip_probability: Probability that the image will be flipped. This is
            computed on a per-axis basis.
        p: Probability that this transform will be applied.
        keys: See :class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> fpg = tio.datasets.FPG()
        >>> flip = tio.RandomFlip(axes=('LR',))  # flip along lateral axis only

    .. tip:: It is handy to specify the axes as anatomical labels when the image
        orientation is not known.
    """

    def __init__(
            self,
            axes: Union[int, str, Tuple[Union[int, str], ...]] = 0,
            flip_probability: float = 0.5,
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        # Integers and label strings are validated here; labels are only
        # resolved to indices once a subject is available.
        self.axes = _parse_axes(axes)
        self.flip_probability = self.parse_probability(flip_probability)

    def apply_transform(self, subject: Subject) -> Subject:
        """Randomly flip ``subject`` along a subset of the allowed axes.

        Args:
            subject: Subject to transform. Its first image is used to resolve
                anatomical labels, if any were given.

        Returns:
            The result of applying a :class:`Flip` over the selected axes
            (possibly none) to ``subject``.
        """
        potential_axes = _ensure_axes_indices(subject, self.axes)
        # One Bernoulli draw is made for each of the three spatial axes;
        # draws for axes not listed in ``self.axes`` are then discarded.
        axes_to_flip_hot = self.get_params(self.flip_probability)
        axes = [
            index
            for index, should_flip in enumerate(axes_to_flip_hot)
            if should_flip and index in potential_axes
        ]
        transform = Flip(axes=axes)
        return transform(subject)

    @staticmethod
    def get_params(probability: float) -> List[bool]:
        """Draw one boolean per spatial axis.

        Args:
            probability: Probability that each individual axis is selected.

        Returns:
            List of three booleans; element ``i`` is ``True`` when axis ``i``
            was selected for flipping.
        """
        return (probability > torch.rand(3)).tolist()
63
64
65
class Flip(SpatialTransform):
    """Deterministically reverse image elements along the given spatial axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image will be flipped. See
            :class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`
            for more information.
        keys: See :class:`~torchio.transforms.Transform`.

    .. tip:: It is handy to specify the axes as anatomical labels when the image
        orientation is not known.
    """

    def __init__(self, axes, keys: Optional[Sequence[str]] = None):
        super().__init__(keys=keys)
        parsed_axes = _parse_axes(axes)
        self.axes = parsed_axes
        # Attribute names that mirror constructor arguments (presumably read
        # by the base class, e.g. to describe/reproduce the transform).
        self.args_names = ('axes',)

    def apply_transform(self, subject: Subject) -> Subject:
        """Flip every selected image of ``subject`` in place and return it."""
        indices = _ensure_axes_indices(subject, self.axes)
        for img in self.get_images(subject):
            # _flip_image overwrites the image data with the flipped copy.
            _flip_image(img, indices)
        return subject

    @staticmethod
    def is_invertible():
        """Flips can always be undone."""
        return True

    def inverse(self):
        """Return the inverse transform.

        Flipping the same axes a second time restores the original order of
        elements, so this transform is its own inverse.
        """
        return self
96
97
98
def _parse_axes(axes: Union[int, str, Tuple[Union[int, str], ...]]):
    """Validate axes given as spatial indices and/or anatomical labels.

    Args:
        axes: A single axis or a sequence of axes. Each axis must be either a
            string (an anatomical label, resolved later by
            :func:`_ensure_axes_indices`) or an integer in ``(0, 1, 2)``.
            NumPy integer scalars are accepted as integers.

    Returns:
        Tuple of axes, with integer axes converted to built-in ``int`` and
        strings kept unchanged.

    Raises:
        ValueError: If an axis is neither a string nor an integer in
            ``(0, 1, 2)``.
    """
    parsed = []
    for axis in to_tuple(axes):
        if isinstance(axis, str):
            parsed.append(axis)
            continue
        # NumPy integer scalars (e.g. results of np.where) are not instances
        # of the built-in int, but are perfectly valid axis indices.
        if isinstance(axis, (int, np.integer)) and axis in (0, 1, 2):
            parsed.append(int(axis))
            continue
        message = (
            'All axes must be 0, 1, 2 or an anatomical label string,'
            f' but found "{axis}" with type {type(axis)}'
        )
        raise ValueError(message)
    return tuple(parsed)
111
112
113
def _ensure_axes_indices(subject, axes):
114
    if any(isinstance(n, str) for n in axes):
115
        subject.check_consistent_orientation()
116
        image = subject.get_first_image()
117
        axes = sorted(3 + image.axis_name_to_index(n) for n in axes)
118
    return axes
119
120
121
def _flip_image(image, axes):
    """Flip ``image`` data in place along the given spatial axes.

    Args:
        image: Image whose ``numpy()`` array is flipped and written back to
            ``image[DATA]`` as a tensor.
        axes: Spatial axis indices (0, 1 or 2) to flip.
    """
    # Spatial indices are shifted by one, i.e. the data array is assumed to
    # carry one leading (presumably channel) dimension before them.
    flip_dims = np.array(axes) + 1
    # np.flip returns a negatively-strided view; copying makes it contiguous
    # so torch.from_numpy can wrap it.
    flipped = np.flip(image.numpy(), axis=flip_dims).copy()
    image[DATA] = torch.from_numpy(flipped)
128