Passed · Pull Request — master (#353) · by Fernando · created 01:16

torchio.transforms.augmentation.composition.compose_from_history() — rated C

Complexity: Conditions 10
Size: Total Lines 41, Code Lines 24
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric   Value
cc       10        (cyclomatic complexity)
eloc     24        (effective lines of code)
nop      1         (number of parameters)
dl       0
loc      41        (lines of code)
rs       5.9999
c        0
b        0
f        0
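For context, the complexity figure reported above (Conditions 10 / cc 10) is the cyclomatic complexity: roughly one plus the number of decision points in the function, although the exact counting rules vary by tool. The toy function below is purely illustrative (it is unrelated to the torchio code under review) and shows how the count accumulates when each if/for and each and/or adds one:

def describe(value):                # +1: base path through the function
    if value is None:               # +1: if
        return 'missing'
    if value < 0 or value > 1:      # +2: if, or
        return 'out of range'
    for _ in range(3):              # +1: for
        pass
    return 'ok'                     # cyclomatic complexity = 5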

1 Method

Rating A — torchio.transforms.augmentation.composition.OneOf._normalize_probabilities() (Duplication 0, Size 19, Complexity 4)

How to fix: Complexity

Complex functions like torchio.transforms.augmentation.composition.compose_from_history() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for statements and local variables that share the same prefixes or suffixes, or that operate on the same data.

Once you have determined which pieces belong together, you can apply the Extract Method refactoring to move them into a small helper. If the extracted component makes sense as its own class, Extract Class is also a candidate.
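The body of compose_from_history() is not reproduced in this report, so the sketch below is only a generic illustration of Extract Method on a hypothetical history-driven function; the entry format and helper names are assumptions, not torchio's API:

from typing import Any, Dict, List, Sequence, Tuple

def compose_from_history_sketch(history: Sequence[Tuple[str, Dict[str, Any]]]) -> List[Any]:
    # After Extract Method, this function only orchestrates; the per-entry
    # branching that would otherwise inflate its cyclomatic complexity lives
    # in the helper below.
    return [instantiate_from_entry(name, kwargs) for name, kwargs in history]

def instantiate_from_entry(name: str, kwargs: Dict[str, Any]) -> Any:
    # Hypothetical helper: one small function per concern keeps each
    # complexity score low. Placeholders stand in for real transform objects.
    if name == 'RandomAffine':
        return ('RandomAffine', kwargs)
    if name == 'RandomElasticDeformation':
        return ('RandomElasticDeformation', kwargs)
    raise ValueError(f'Unknown transform in history: {name}')

# Example call with a made-up history entry:
# compose_from_history_sketch([('RandomAffine', {'degrees': 10})])

The module listing under review follows.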

import warnings
from typing import Union, Sequence, Dict

import torch
import numpy as np
from torchvision.transforms import Compose as PyTorchCompose

from ...data.subject import Subject
from .. import Transform
from . import RandomTransform


TypeTransformsDict = Union[Dict[Transform, float], Sequence[Transform]]


class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(self, transforms: Sequence[Transform], p: float = 1):
        super().__init__(p=p)
        if not transforms:
            raise ValueError('The list of transforms is empty')
        self.transform = PyTorchCompose(transforms)
        self.transforms = self.transform.transforms

    def __len__(self):
        return len(self.transforms)

    def __getitem__(self, index) -> Transform:
        return self.transforms[index]

    def __repr__(self) -> str:
        return self.transform.__repr__()

    def apply_transform(self, subject: Subject) -> Subject:
        return self.transform(subject)

    def is_invertible(self) -> bool:
        return all(t.is_invertible() for t in self.transforms)

    def inverse(self) -> Transform:
        transforms = []
        for transform in self.transforms:
            if transform.is_invertible():
                transforms.append(transform.inverse())
            else:
                # Warn and skip, so the behavior matches the message
                message = f'Skipping {transform.name} as it is not invertible'
                warnings.warn(message)
        transforms.reverse()
        return Compose(transforms)


class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: TypeTransformsDict,
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        weights = torch.Tensor(list(self.transforms_dict.values()))
        index = torch.multinomial(weights, 1)
        transforms = list(self.transforms_dict.keys())
        transform = transforms[index]
        transformed = transform(subject)
        return transformed

    def _get_transforms_dict(
            self,
            transforms: TypeTransformsDict,
            ) -> Dict[Transform, float]:
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                p = 1 / len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transforms_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(
            transforms_dict: Dict[Transform, float],
            ) -> None:
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater than or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / probabilities.sum()
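A short usage sketch of the two classes in this listing, adapted from the OneOf docstring example above (it assumes torchio re-exports these transforms at the top level, as that example does, and that the Colin27 sample dataset can be downloaded):

import torchio as tio

colin = tio.datasets.Colin27()
augment = tio.OneOf({
    tio.RandomAffine(): 0.75,              # drawn ~75% of the time
    tio.RandomElasticDeformation(): 0.25,  # drawn ~25% of the time
})
transform = tio.Compose([augment])
transformed = transform(colin)
print(len(transform))  # 1: the Compose wraps a single OneOf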