Passed
Pull Request — master (#353) · by Fernando · 01:16 · created

torchio.transforms.augmentation.composition.compose_from_history()   Rating: C

Complexity: Conditions 10
Size: Total Lines 41 · Code Lines 24
Duplication: Lines 0 · Ratio 0 %
Importance: Changes 0
Metric   Value
cc       10
eloc     24
nop      1
dl       0
loc      41
rs       5.9999
c        0
b        0
f        0

1 Method

Rating   Name                                                                            Duplication   Size   Complexity
A        torchio.transforms.augmentation.composition.OneOf._normalize_probabilities()   0             19     4

How to fix: Complexity

Complex classes like torchio.transforms.augmentation.composition.compose_from_history() often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring, as in the sketch below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
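As an illustration only, here is a minimal Extract Class sketch in Python. The HistoryParser class, its methods, and the simplified compose_from_history() body are hypothetical and are not torchio's actual implementation; the point is that fields and helpers sharing one concern move into their own class, which lowers the complexity of the original function.

# Before: one long function parses the history entries *and* builds the result.
# After (hypothetical): the parsing logic lives in its own small class.

from typing import Dict, List, Sequence


class HistoryParser:  # hypothetical extracted component
    """Groups the helpers that previously shared a 'history' prefix."""

    def __init__(self, history: Sequence[Dict]):
        self.history = history

    def transform_names(self) -> List[str]:
        # One small, testable responsibility per method
        return [entry['name'] for entry in self.history]

    def transform_kwargs(self) -> List[Dict]:
        return [entry.get('kwargs', {}) for entry in self.history]


def compose_from_history(history: Sequence[Dict]):
    # The original function now only delegates and assembles the result
    parser = HistoryParser(history)
    return list(zip(parser.transform_names(), parser.transform_kwargs()))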

import warnings
from typing import Union, Sequence, Dict

import torch
import numpy as np
from torchvision.transforms import Compose as PyTorchCompose

from ...data.subject import Subject
from .. import Transform
from . import RandomTransform


TypeTransformsDict = Union[Dict[Transform, float], Sequence[Transform]]


class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(self, transforms: Sequence[Transform], p: float = 1):
        super().__init__(p=p)
        if not transforms:
            raise ValueError('The list of transforms is empty')
        self.transform = PyTorchCompose(transforms)
        self.transforms = self.transform.transforms

    def __len__(self):
        return len(self.transforms)

    def __getitem__(self, index) -> Transform:
        return self.transforms[index]

    def __repr__(self) -> str:
        return self.transform.__repr__()

    def apply_transform(self, subject: Subject) -> Subject:
        return self.transform(subject)

    def is_invertible(self) -> bool:
        return all(t.is_invertible() for t in self.transforms)

    def inverse(self) -> Transform:
        # Invert each invertible transform and apply them in reverse order
        transforms = []
        for transform in self.transforms:
            if transform.is_invertible():
                transforms.append(transform.inverse())
            else:
                message = f'Skipping {transform.name} as it is not invertible'
                warnings.warn(message, UserWarning)
        transforms.reverse()
        return Compose(transforms)


class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: TypeTransformsDict,
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        # Sample one transform according to the stored probabilities
        weights = torch.Tensor(list(self.transforms_dict.values()))
        index = torch.multinomial(weights, 1)
        transforms = list(self.transforms_dict.keys())
        transform = transforms[index]
        transformed = transform(subject)
        return transformed

    def _get_transforms_dict(
            self,
            transforms: TypeTransformsDict,
            ) -> Dict[Transform, float]:
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                p = 1 / len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transforms_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(
            transforms_dict: Dict[Transform, float],
            ) -> None:
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater than or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / probabilities.sum()
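For reference, a short usage sketch of the two classes listed above. It follows the documented API in this file; tio.RandomFlip is a standard torchio transform not shown here, and the exact behavior of inverse() depends on which member transforms are invertible.

import torchio as tio

colin = tio.datasets.Colin27()

# Compose applies every transform in order; OneOf samples exactly one of its
# transforms, with probabilities normalized to sum to one.
augment = tio.Compose([
    tio.RandomFlip(axes=('LR',)),
    tio.OneOf({
        tio.RandomAffine(): 0.75,
        tio.RandomElasticDeformation(): 0.25,
    }),
])

transformed = augment(colin)

# is_invertible() is True only if every member transform is invertible;
# inverse() skips (and warns about) the ones that are not.
print(augment.is_invertible())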