Passed
Pull Request — master (#353)
by Fernando
01:17
created

Compose.is_invertible()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import warnings
2
from typing import Callable, Dict, Sequence, Union
3
4
import torch
5
import numpy as np
6
from torchvision.transforms import Compose as PyTorchCompose
7
8
from ...data.subject import Subject
9
from .. import Transform
10
from . import RandomTransform
11
12
13
# Transforms argument accepted by OneOf: either a dict mapping each transform
# to its (not necessarily normalized) probability, or a sequence of transforms
# that will all be assigned the same probability.
TypeTransformsDict = Union[Dict[Transform, float], Sequence[Transform]]
14
15
16
class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    Raises:
        ValueError: If ``transforms`` is empty.
        TypeError: If any element of ``transforms`` is not callable.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(self, transforms: Sequence[Transform], p: float = 1):
        super().__init__(p=p)
        if not transforms:
            raise ValueError('The list of transforms is empty')
        for transform in transforms:
            if not callable(transform):
                message = (
                    'One or more of the objects passed to the Compose transform'
                    f' are not callable: "{transform}"'
                )
                raise TypeError(message)
        # Sequential application is delegated to torchvision's Compose; its
        # ``transforms`` attribute is shared so that indexing, ``len`` and
        # inversion see exactly the list that will be applied.
        self.transform = PyTorchCompose(transforms)
        self.transforms = self.transform.transforms

    def __len__(self):
        """Return the number of wrapped transforms."""
        return len(self.transforms)

    def __getitem__(self, index) -> Transform:
        """Return the wrapped transform(s) at ``index``."""
        return self.transforms[index]

    def __repr__(self) -> str:
        return repr(self.transform)

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply every wrapped transform to ``subject``, in order."""
        return self.transform(subject)

    def is_invertible(self) -> bool:
        """Return ``True`` only if every wrapped transform is invertible."""
        return all(t.is_invertible() for t in self.transforms)

    def inverse(self, warn: bool = True) -> Union[Transform, Callable]:
        """Return a transform that undoes the invertible wrapped transforms.

        The inverses are collected and applied in reverse order. Transforms
        that are not invertible are skipped.

        Args:
            warn: If ``True``, emit a ``UserWarning`` for each skipped
                transform and a ``RuntimeWarning`` if none of the transforms
                could be inverted.

        Returns:
            A :class:`Compose` of the inverses or, if no wrapped transform is
            invertible, a callable that returns its input unchanged.
        """
        transforms = []
        for transform in self.transforms:
            if transform.is_invertible():
                transforms.append(transform.inverse())
            elif warn:
                message = f'Skipping {transform.name} as it is not invertible'
                warnings.warn(message, UserWarning)
        transforms.reverse()
        if transforms:
            return Compose(transforms)
        # Compose rejects an empty list; with nothing to undo, the inverse is
        # the identity, so return a no-op instead of raising a confusing
        # "The list of transforms is empty" error.
        if warn:
            warnings.warn('No invertible transforms found', RuntimeWarning)

        def identity(subject):
            return subject

        return identity
66
67
68
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Raises:
        ValueError: If ``transforms`` is empty, is neither a dictionary nor a
            sequence, contains objects that are not transforms, or has
            negative, non-finite or all-zero probabilities.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.transforms.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: TypeTransformsDict,
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        """Pick one transform at random, by its weight, and apply it."""
        transforms = list(self.transforms_dict.keys())
        weights = torch.tensor(
            list(self.transforms_dict.values()),
            dtype=torch.float,
        )
        # multinomial returns a 1-element tensor; convert it to a Python int
        # before using it as a list index.
        index = torch.multinomial(weights, 1).item()
        transform = transforms[index]
        return transform(subject)

    def _get_transforms_dict(
            self,
            transforms: TypeTransformsDict,
            ) -> Dict[Transform, float]:
        """Build a ``{transform: probability}`` dict from the user input.

        A dict is copied and its probabilities normalized. Any other sized
        collection gets a uniform probability for each element.
        """
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            # Without this check an empty sequence raises ZeroDivisionError.
            if num_transforms == 0:
                raise ValueError('The list of transforms is empty')
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(
            transforms_dict: Dict[Transform, float],
            ) -> None:
        """Validate the probabilities and rescale them in place to sum to 1.

        Raises:
            ValueError: If any probability is negative or not finite, or if
                all probabilities are zero.
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if not np.all(np.isfinite(probabilities)):
            message = (
                'Probabilities must be finite numbers,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        # Normalize once, using the validated float array, instead of
        # recomputing the sum for every key.
        normalized = probabilities / probabilities.sum()
        for transform, probability in zip(list(transforms_dict), normalized):
            transforms_dict[transform] = probability
151