Passed
Pull Request — master (#353)
by Fernando
01:11
created

torchio.transforms.augmentation.spatial.random_flip   A

Complexity

Total Complexity 16

Size/Duplication

Total Lines 125
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 69
dl 0
loc 125
rs 10
c 0
b 0
f 0
wmc 16

6 Methods

Rating   Name   Duplication   Size   Complexity  
A Flip.apply_transform() 0 5 2
A RandomFlip.__init__() 0 10 1
A RandomFlip.apply_transform() 0 12 3
A Flip.__init__() 0 3 1
A Flip.inverse() 0 2 1
A RandomFlip.get_params() 0 3 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _ensure_axes_indices() 0 6 2
A _parse_axes() 0 13 4
A _flip_image() 0 7 1
1
from typing import Union, Tuple, Optional, List
2
import torch
3
import numpy as np
4
from ....torchio import DATA
5
from ....data.subject import Subject
6
from ....utils import to_tuple
7
from ... import SpatialTransform
8
from .. import RandomTransform
9
10
11
class RandomFlip(RandomTransform, SpatialTransform):
    """Reverse the order of elements in an image along the given axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image might be flipped. If they are integers, they must be in
            ``(0, 1, 2)``. Anatomical labels may also be used, such as
            ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
            ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``,
            ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or
            ``'i'`` (inferior). Only the first letter of the string will be
            used. If the image is 2D, ``'Height'`` and ``'Width'`` may be
            used.
        flip_probability: Probability that the image will be flipped. This is
            computed on a per-axis basis.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> fpg = tio.datasets.FPG()
        >>> flip = tio.RandomFlip(axes=('LR',))  # flip along lateral axis only
        >>> flipped = flip(fpg)

    .. tip:: It is handy to specify the axes as anatomical labels when the image
        orientation is not known.
    """

    def __init__(
            self,
            axes: Union[int, str, Tuple[Union[int, str], ...]] = 0,
            flip_probability: float = 0.5,
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.axes = _parse_axes(axes)
        self.flip_probability = self.parse_probability(flip_probability)

    def apply_transform(self, subject: Subject) -> Subject:
        """Flip ``subject`` along a random subset of the allowed axes.

        Each of the three spatial axes gets an independent draw with
        probability ``flip_probability``. An axis is flipped only if its
        draw succeeds and it is one of the configured ``axes``.
        """
        potential_axes = _ensure_axes_indices(subject, self.axes)
        flip_draws = self.get_params(self.flip_probability)
        axes = [
            index
            for index, should_flip in enumerate(flip_draws)
            if should_flip and index in potential_axes
        ]
        arguments = {'axes': axes}
        transform = Flip(**arguments)
        transformed = transform(subject)
        # Record the concrete Flip that was applied, with the axes it used
        transformed.add_transform(transform, arguments)
        return transformed

    @staticmethod
    def get_params(probability: float) -> List[bool]:
        """Return one boolean per spatial axis, each True with ``probability``."""
        return (probability > torch.rand(3)).tolist()
65
66
67
class Flip(SpatialTransform):
    """Deterministically reverse image data along the given spatial axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image will be flipped. See
            :py:class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`
            for more information.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. tip:: It is handy to specify the axes as anatomical labels when the image
        orientation is not known.
    """

    def __init__(self, axes, keys: Optional[List[str]] = None):
        super().__init__(keys=keys)
        self.axes = _parse_axes(axes)

    def apply_transform(self, subject: Subject) -> Subject:
        """Flip every selected image of ``subject`` in place and return it."""
        # String labels are converted to integer indices using the subject's
        # first image; integer axes are used as given
        spatial_axes = _ensure_axes_indices(subject, self.axes)
        selected_images = self.get_images(subject)
        for img in selected_images:
            _flip_image(img, spatial_axes)
        return subject

    def inverse(self):
        """Return this transform: flipping twice along the same axes undoes it."""
        return self
93
94
95
def _parse_axes(axes: Union[int, str, Tuple[Union[int, str], ...]]):
    """Validate an axes specification and return it as a tuple.

    Args:
        axes: A single axis or a sequence of axes. Each axis must be an
            integer in ``(0, 1, 2)`` (a Python or NumPy integer) or a
            non-empty anatomical label string, which is converted to an index
            when the transform is applied.

    Returns:
        Tuple of axes, with NumPy integers converted to Python ``int`` and
        strings left unchanged.

    Raises:
        ValueError: If an axis is neither a valid integer index nor a
            non-empty string.
    """
    parsed = []
    for axis in to_tuple(axes):
        if isinstance(axis, str) and axis:
            parsed.append(axis)
            continue
        # NumPy integers (e.g. values taken from ``np.where`` output) are
        # accepted as well and normalized to plain ints
        is_integer = isinstance(axis, (int, np.integer))
        if is_integer and axis in (0, 1, 2):
            parsed.append(int(axis))
            continue
        message = (
            f'All axes must be 0, 1 or 2, or a non-empty anatomical label,'
            f' but found "{axis}" with type {type(axis)}'
        )
        raise ValueError(message)
    return tuple(parsed)
108
109
110
def _ensure_axes_indices(subject, axes):
111
    if any(isinstance(n, str) for n in axes):
112
        subject.check_consistent_orientation()
113
        image = subject.get_first_image()
114
        axes = sorted(3 + image.axis_name_to_index(n) for n in axes)
115
    return axes
116
117
118
def _flip_image(image, axes):
    """Replace the data of ``image`` with a copy flipped along ``axes``.

    Args:
        image: Image whose data is read with ``image.numpy()`` and written
            back with ``image[DATA] = ...``.
        axes: Iterable of spatial axis indices in ``(0, 1, 2)``. May be
            empty, in which case the data is copied unchanged.
    """
    # The ``+ 1`` skips the first dimension of the data array (presumably
    # the channel dimension — TODO confirm). Building a tuple of Python ints
    # keeps ``axis`` an integer sequence even when ``axes`` is empty, whereas
    # ``np.array(()) + 1`` would produce a float array.
    spatial_axes = tuple(int(axis) + 1 for axis in axes)
    flipped = np.flip(image.numpy(), axis=spatial_axes)
    # np.flip returns a view with negative strides, which torch.from_numpy
    # cannot wrap, so materialize a contiguous copy first
    image[DATA] = torch.from_numpy(flipped.copy())
125