Completed
Pull Request — master (#402)
by Fernando
04:20 queued 54s
created

Compose.inverse()   B

Complexity

Conditions 6

Size

Total Lines 21
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 15
nop 2
dl 0
loc 21
rs 8.6666
c 0
b 0
f 0
1
import warnings
2
from typing import Union, Sequence, Dict
3
4
import torch
5
import numpy as np
6
7
from ...data.subject import Subject
8
from .. import Transform
9
from . import RandomTransform
10
11
12
# Accepted forms for the ``transforms`` argument of ``OneOf``: either a dict
# mapping each transform to a weight (``OneOf`` normalizes the weights so they
# sum to one), or a plain sequence of transforms, which are then all assigned
# the same probability.
TypeTransformsDict = Union[Dict[Transform, float], Sequence[Transform]]
13
14
15
class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :class:`~torchio.transforms.Transform`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments.

    Raises:
        ValueError: If no transforms are given.
        TypeError: If any of the given objects is not callable.
    """
    def __init__(self, transforms: Sequence[Transform], **kwargs):
        super().__init__(**kwargs)
        # Materialize into a private list. An iterator is always truthy, so
        # without this it would pass the emptiness check below and then be
        # exhausted by the validation loop, leaving nothing to apply. The copy
        # also keeps later mutations of the caller's list from leaking in.
        # Falsy inputs (e.g. None or an empty sequence) map to [] so they keep
        # raising ValueError rather than failing inside list().
        transforms = list(transforms) if transforms else []
        if not transforms:
            raise ValueError('The list of transforms is empty')
        for transform in transforms:
            if not callable(transform):
                message = (
                    'One or more of the objects passed to the Compose transform'
                    f' are not callable: "{transform}"'
                )
                raise TypeError(message)
        self.transforms = transforms

    def __len__(self):
        """Return the number of composed transforms."""
        return len(self.transforms)

    def __getitem__(self, index) -> Transform:
        """Return the transform at ``index`` (a slice returns a list)."""
        return self.transforms[index]

    def __repr__(self) -> str:
        return repr(self.transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply each transform in order, feeding each output to the next."""
        for transform in self.transforms:
            subject = transform(subject)
        return subject

    def is_invertible(self) -> bool:
        """Return ``True`` only if every composed transform is invertible."""
        return all(t.is_invertible() for t in self.transforms)

    def inverse(self, warn: bool = True) -> Transform:
        """Return a composed transform with inverted order and transforms.

        Transforms that are not invertible are skipped.

        Args:
            warn: Issue a warning if some transforms are not invertible.

        Returns:
            A :class:`Compose` holding the inverses in reverse order or, if
            no transform is invertible, a plain identity function that returns
            its input unchanged (a :class:`Compose` cannot be empty).
        """
        transforms = []
        for transform in self.transforms:
            if transform.is_invertible():
                transforms.append(transform.inverse())
            elif warn:
                message = f'Skipping {transform.name} as it is not invertible'
                # stacklevel=2 attributes the warning to the caller of inverse()
                warnings.warn(message, RuntimeWarning, stacklevel=2)
        transforms.reverse()
        if transforms:
            result = Compose(transforms)
        else:  # return noop if no invertible transforms are found
            def result(x): return x  # noqa: E704
            if warn:
                warnings.warn(
                    'No invertible transforms found',
                    RuntimeWarning,
                    stacklevel=2,
                )
        return result
75
76
77
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :class:`~torchio.transforms.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments.

    Raises:
        ValueError: If ``transforms`` is empty, is neither a dictionary nor a
            sequence, contains objects that are not transforms, or has
            negative, non-finite or all-zero probabilities.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: TypeTransformsDict,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample one transform according to its probability and apply it."""
        weights = torch.Tensor(list(self.transforms_dict.values()))
        # multinomial returns a 1-element index tensor; convert it to a plain
        # int before using it to index a Python list.
        index = torch.multinomial(weights, 1).item()
        transforms = list(self.transforms_dict.keys())
        transform = transforms[index]
        transformed = transform(subject)
        return transformed

    def _get_transforms_dict(
            self,
            transforms: TypeTransformsDict,
            ) -> Dict[Transform, float]:
        """Validate ``transforms`` and map each transform to its probability.

        A dict is copied and its values normalized; a sequence gets a uniform
        probability of ``1 / len(transforms)`` per element.

        Raises:
            ValueError: If ``transforms`` is empty, is not a dict or a sized
                sequence, or contains keys that are not ``Transform`` instances.
        """
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            # Without this an empty sequence raised ZeroDivisionError, while an
            # empty dict already raised ValueError in _normalize_probabilities.
            if num_transforms == 0:
                raise ValueError('The list of transforms is empty')
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(
            transforms_dict: Dict[Transform, float],
            ) -> None:
        """Divide every value of ``transforms_dict`` in place by their sum.

        Raises:
            ValueError: If any probability is negative or not finite, or if
                all of them are zero (so the sum cannot be used to normalize).
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        # NaN/inf would otherwise survive validation and only fail later,
        # inside torch.multinomial, when the transform is applied.
        if not np.all(np.isfinite(probabilities)):
            message = (
                'Probabilities must be finite,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        total = probabilities.sum()  # hoisted: invariant across the loop
        for transform, probability in transforms_dict.items():
            # Reassigning existing keys does not resize the dict, so this is
            # safe while iterating over its items.
            transforms_dict[transform] = probability / total
160