Passed
Pull Request — master (#353)
by Fernando
01:05
created

Compose.__len__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
from typing import Union, Sequence, Dict
2
3
import torch
4
import numpy as np
5
from torchvision.transforms import Compose as PyTorchCompose
6
7
from ...data.subject import Subject
8
from .. import Transform
9
from . import RandomTransform
10
11
12
# Accepted forms for the ``transforms`` argument of ``OneOf``: either a
# mapping from each transform to its (not necessarily normalized) weight, or a
# plain sequence of transforms, in which case every transform gets the same
# weight.
TypeTransformsDict = Union[
    Dict[Transform, float],
    Sequence[Transform],
]
13
14
15
class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(self, transforms: Sequence[Transform], p: float = 1):
        super().__init__(p=p)
        # Sequential application is delegated to torchvision's Compose; its
        # list is exposed so len/indexing/inversion see the same objects.
        self.transform = PyTorchCompose(transforms)
        self.transforms = self.transform.transforms

    def __len__(self) -> int:
        """Return the number of composed transforms."""
        return len(self.transforms)

    def __getitem__(self, index) -> Transform:
        """Return the composed transform(s) at ``index``."""
        return self.transforms[index]

    def __repr__(self) -> str:
        return repr(self.transform)

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply every composed transform to ``subject``, in order."""
        return self.transform(subject)

    def is_invertible(self) -> bool:
        """Return ``True`` only if every composed transform is invertible."""
        return all(t.is_invertible() for t in self.transforms)

    def inverse(self) -> Transform:
        """Return a :class:`Compose` that undoes this composition.

        The inverses of the invertible transforms are applied in reverse
        order. Transforms that are not invertible are skipped, and a
        :class:`UserWarning` is emitted for each of them.

        Returns:
            A new :class:`Compose` with the inverted transforms.
        """
        import warnings

        transforms = []
        for transform in self.transforms:
            if transform.is_invertible():
                transforms.append(transform.inverse())
            else:
                # Warn and continue: raising here would abort the whole
                # inversion, contradicting the "Skipping" message.
                message = f'Skipping {transform.name} as it is not invertible'
                warnings.warn(message, UserWarning, stacklevel=2)
        transforms.reverse()
        return Compose(transforms)
56
57
58
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Raises:
        ValueError: If ``transforms`` is neither a dictionary nor a sequence,
            is empty, contains something that is not a ``Transform``, or has
            negative, non-finite or all-zero probabilities.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.transforms.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: TypeTransformsDict,
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample one transform according to its weight and apply it."""
        transforms = list(self.transforms_dict.keys())
        weights = torch.Tensor(list(self.transforms_dict.values()))
        # multinomial returns a one-element tensor; convert it to a plain int
        # before using it as a list index.
        index = torch.multinomial(weights, 1).item()
        transform = transforms[index]
        return transform(subject)

    def _get_transforms_dict(
            self,
            transforms: TypeTransformsDict,
            ) -> Dict[Transform, float]:
        """Build and validate the ``transform -> probability`` mapping.

        A dictionary is copied and its values normalized; a sequence gets a
        uniform probability for each element.
        """
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            if num_transforms == 0:
                # Without this check, 1 / 0 would raise ZeroDivisionError.
                raise ValueError('At least one transform must be given')
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(
            transforms_dict: Dict[Transform, float],
            ) -> None:
        """Scale the values of ``transforms_dict`` in place so they sum to 1.

        Raises:
            ValueError: If any probability is negative or not finite, or if
                all of them are zero.
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if not np.all(np.isfinite(probabilities)):
            # NaN/inf would pass the checks above and turn every normalized
            # value into NaN.
            message = f'Probabilities must be finite, not "{probabilities}"'
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        total = probabilities.sum()  # computed once, not once per key
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / total
141