Passed
Pull Request — master (#353)
by Fernando
01:15
created

torchio.transforms.augmentation.composition.compose_from_history() (rating: C)

Complexity

Conditions 10

Size

Total Lines 41
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      10
eloc    24
nop     1
dl      0
loc     41
rs      5.9999
c       0
b       0
f       0

How to fix: Complexity

Complex classes and functions like torchio.transforms.augmentation.composition.compose_from_history() often do many different things. To break one down, identify a cohesive component within it. A common way to find such a component is to look for fields, methods, or variables that share the same prefixes or suffixes.

Once you have determined which parts belong together, you can apply the Extract Class refactoring; for a function such as this one, the equivalent move is Extract Method. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
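Because compose_from_history() is a function rather than a class, a minimal sketch of Extract Method may be more useful here than Extract Class. The code below is hypothetical and is not the torchio implementation: summarize_before() mixes validation, normalization, and formatting in one branchy body, while summarize_after() extracts each concern into a named helper (_validate, _normalize, _format_line, all illustrative names). Each function ends up with fewer conditions, and the output stays the same.

```python
from typing import Dict, List


# BEFORE (hypothetical): validation, normalization and formatting are mixed
# in one body, so every concern adds branches to the same function.
def summarize_before(weights: Dict[str, float]) -> List[str]:
    if not weights:
        raise ValueError('No weights given')
    total = 0.0
    for name, weight in weights.items():
        if weight < 0:
            raise ValueError(f'Negative weight for {name!r}')
        total += weight
    if total == 0:
        raise ValueError('All weights are zero')
    lines = []
    for name, weight in weights.items():
        share = weight / total
        if share >= 0.5:
            lines.append(f'{name}: {share:.2f} (dominant)')
        else:
            lines.append(f'{name}: {share:.2f}')
    return lines


# AFTER: each cohesive step is extracted into a small, named helper.
def _validate(weights: Dict[str, float]) -> None:
    if not weights:
        raise ValueError('No weights given')
    for name, weight in weights.items():
        if weight < 0:
            raise ValueError(f'Negative weight for {name!r}')
    if sum(weights.values()) == 0:
        raise ValueError('All weights are zero')


def _normalize(weights: Dict[str, float]) -> Dict[str, float]:
    total = sum(weights.values())
    return {name: weight / total for name, weight in weights.items()}


def _format_line(name: str, share: float) -> str:
    suffix = ' (dominant)' if share >= 0.5 else ''
    return f'{name}: {share:.2f}{suffix}'


def summarize_after(weights: Dict[str, float]) -> List[str]:
    # The top-level function now only orchestrates the helpers.
    _validate(weights)
    return [_format_line(name, share) for name, share in _normalize(weights).items()]
```

Both versions return the same list for the same input; the gain is that each helper can be read, named, and tested on its own.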

from typing import Union, Sequence

import torch
import numpy as np
from torchvision.transforms import Compose as PyTorchCompose

from ...data.subject import Subject
from .. import Transform
from . import RandomTransform


class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper around :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(self, transforms: Sequence[Transform], p: float = 1):
        super().__init__(p=p)
        self.transform = PyTorchCompose(transforms)
        self.transforms = self.transform.transforms

    def __len__(self):
        return len(self.transforms)

    def __getitem__(self, index):
        return self.transforms[index]

    def __repr__(self):
        return self.transform.__repr__()

    def apply_transform(self, subject: Subject):
        return self.transform(subject)

    def is_invertible(self):
        return all(t.is_invertible() for t in self.transforms)

    def inverse(self):
        # Undo the transforms in reverse order. reversed() needs a sequence
        # (a generator is not reversible), so reverse the list first and
        # then invert each transform.
        return Compose([t.inverse() for t in reversed(self.transforms)])


class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.transforms.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: Union[dict, Sequence[Transform]],
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject):
        # Sample one index with probability proportional to its weight
        weights = torch.Tensor(list(self.transforms_dict.values()))
        index = torch.multinomial(weights, 1).item()
        transforms = list(self.transforms_dict.keys())
        transform = transforms[index]
        transformed = transform(subject)
        return transformed

    def _get_transforms_dict(self, transforms: Union[dict, Sequence]):
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                p = 1 / len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transforms_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(transforms_dict: dict):
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater than or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        # Rescale in place so that the probabilities sum to one
        total = probabilities.sum()
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / total
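The core logic of Compose.inverse() and OneOf can be illustrated without torch or torchio. The sketch below is an assumption-laden stand-in, not the library's API: Shift, Scale, SketchCompose, normalize() and one_of() are hypothetical names, plain numbers stand in for a Subject, and random.choices replaces torch.multinomial. It shows why an inverse must undo the steps in reverse order (calling reversed() directly on a generator raises TypeError, because generators are not reversible) and how dictionary weights such as 3 and 1 are normalized to 0.75 and 0.25 before one transform is sampled.

```python
import random
from typing import Dict, Sequence, Union


# Hypothetical invertible transforms acting on plain numbers.
class Shift:
    def __init__(self, offset: float):
        self.offset = offset

    def __call__(self, x: float) -> float:
        return x + self.offset

    def inverse(self) -> 'Shift':
        return Shift(-self.offset)


class Scale:
    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, x: float) -> float:
        return x * self.factor

    def inverse(self) -> 'Scale':
        return Scale(1 / self.factor)


class SketchCompose:
    """Applies transforms in order; its inverse undoes them in reverse order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x

    def inverse(self) -> 'SketchCompose':
        # Reverse the list (which is reversible), then invert each step.
        return SketchCompose([t.inverse() for t in reversed(self.transforms)])


def normalize(weights: Dict[object, float]) -> Dict[object, float]:
    """Reject negative or all-zero weights, then rescale them to sum to one."""
    values = list(weights.values())
    if any(value < 0 for value in values):
        raise ValueError(f'Probabilities must be greater than or equal to zero, not {values}')
    total = sum(values)
    if total == 0:
        raise ValueError(f'At least one probability must be greater than zero, but they are {values}')
    return {key: value / total for key, value in weights.items()}


def one_of(transforms: Union[Dict[object, float], Sequence[object]], rng: random.Random):
    """Pick exactly one transform, with probability proportional to its weight."""
    if isinstance(transforms, dict):
        weights = normalize(transforms)
    else:
        # A plain sequence gets the same probability for every transform.
        items = list(transforms)
        weights = {item: 1 / len(items) for item in items}
    return rng.choices(list(weights), weights=list(weights.values()), k=1)[0]
```

For example, SketchCompose([Shift(1), Scale(2)]) maps 3 to 8, and its inverse applies Scale(0.5) and then Shift(-1), mapping 8 back to 3. Applying the inverses in the original order instead would give 8 * 1 - ... no longer recover the input in general, which is why the order matters.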