Passed
Pull Request — master (#287)
by Fernando
01:34
created

compose_from_history()   C

Complexity

Conditions 10

Size

Total Lines 41
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 10
eloc 24
nop 1
dl 0
loc 41
rs 5.9999
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex functions like torchio.transforms.augmentation.composition.compose_from_history() often do a lot of different things. To break such a function down, identify a cohesive piece of its logic (for example, parsing stored parameters versus building objects) and extract it into its own helper. For classes, a common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from typing import Union, Sequence, List
2
3
import json
4
import torch
5
import torchio
6
import numpy as np
7
from torchvision.transforms import Compose as PyTorchCompose
8
9
from ...data.subject import Subject
10
from .. import Transform
11
from . import RandomTransform, Interpolation
12
13
14
class Compose(Transform):
    """Chain several transforms so they are applied one after another.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """

    def __init__(self, transforms: Sequence[Transform], p: float = 1):
        super().__init__(p=p)
        # The actual chaining is delegated to torchvision's Compose.
        pipeline = PyTorchCompose(transforms)
        self.transform = pipeline

    def apply_transform(self, subject: Subject):
        """Feed ``subject`` through the wrapped torchvision ``Compose``."""
        pipeline = self.transform
        transformed = pipeline(subject)
        return transformed
31
32
33
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Raises:
        ValueError: If ``transforms`` is neither a dictionary nor a sized
            sequence, if it is empty, if any key is not a
            :py:class:`~torchio.transforms.transform.Transform`, or if the
            probabilities are negative, non-finite or all zero.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.transforms.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: Union[dict, Sequence[Transform]],
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject):
        """Pick one transform at random (weighted) and apply it to ``subject``."""
        transforms = list(self.transforms_dict.keys())
        weights = torch.Tensor(list(self.transforms_dict.values()))
        # Draw a single index with probability proportional to its weight;
        # .item() turns the 1-element index tensor into a plain int.
        index = torch.multinomial(weights, 1).item()
        transform = transforms[index]
        transformed = transform(subject)
        return transformed

    def _get_transforms_dict(self, transforms: Union[dict, Sequence]):
        """Build a ``{transform: probability}`` dict from the constructor input."""
        if isinstance(transforms, dict):
            # Copy so the caller's dictionary is not normalized in place
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            if num_transforms == 0:
                # Avoid a ZeroDivisionError; report the problem like the
                # other input errors of this class
                raise ValueError('At least one transform must be given')
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(transforms_dict: dict):
        """Validate the dict's probabilities and rescale them in place to sum to one."""
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        # NaN/inf would survive the checks below and make the weights invalid
        if not np.all(np.isfinite(probabilities)):
            message = (
                'Probabilities must be finite,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        # Compute the total once instead of on every iteration
        normalized = probabilities / probabilities.sum()
        for transform, probability in zip(list(transforms_dict), normalized):
            transforms_dict[transform] = probability
111
112
113
def _split_history_params(trsfm_params: dict):
    """Separate a history entry's parameters into its seed and its attributes.

    Returns:
        Tuple ``(seed, attributes)``. ``seed`` is ``None`` when the entry has
        no ``'seed'`` key; ``attributes`` holds every other parameter, decoded.
    """
    seed = trsfm_params.get('seed')
    attributes = {}
    for key, value in trsfm_params.items():
        if key == 'seed':
            continue
        # Ugly fix for RandomSwap's patch_size: list values are stored as
        # JSON strings in the history, so decode them back into lists.
        if isinstance(value, str) and value.startswith('['):
            value = json.loads(value)
        attributes[key] = value
    # The interpolation is stored as a string (e.g. 'Interpolation.LINEAR');
    # look up the enum member by the part after the last dot.
    interpolation = attributes.get('interpolation')
    if isinstance(interpolation, str):
        member_name = interpolation.split('.')[-1]
        attributes['interpolation'] = getattr(Interpolation, member_name)
    return seed, attributes


def _build_transform(trsfm_name: str, attributes: dict):
    """Instantiate ``torchio.<trsfm_name>`` and load the history attributes into it."""
    transform_class = getattr(torchio, trsfm_name)
    # Special cases: these transforms need an argument in their __init__
    if trsfm_name == 'RandomLabelsToImage':
        transform = transform_class(label_key=attributes['label_key'])
    elif trsfm_name == 'Resample' and 'target' in attributes:
        transform = transform_class(target=attributes['target'])
    elif trsfm_name == 'Resample' and 'target_spacing' in attributes:
        transform = transform_class(target=attributes['target_spacing'])
    else:
        # Includes a Resample entry with no recorded target: build it with
        # its defaults instead of leaving the transform undefined (previously
        # an UnboundLocalError, or silent reuse of the prior transform).
        transform = transform_class()
    # Overwrite with the recorded attributes, but keep anything __init__ set
    # that the history does not record.
    transform.__dict__.update(attributes)
    return transform


def compose_from_history(history: List):
    """Builds a list of transformations and seeds to reproduce a given subject's transformations from its history

    Args:
        history: subject history given as a list of tuples containing
            (transformation_name, transformation_parameters), where the
            parameters are a dict that may contain a ``'seed'`` key.

    Returns:
        Tuple (List of transforms, list of seeds to reproduce the transforms
        from the history). The seed is ``None`` for entries without a seed.
        Both lists have the same length and order.

    Raises:
        AttributeError: If a transformation name is not found in ``torchio``.
        KeyError: If a ``RandomLabelsToImage`` entry has no ``label_key``.
        json.JSONDecodeError: If a parameter string starting with ``[`` is
            not valid JSON.
    """
    trsfm_list = []
    seed_list = []
    for trsfm_name, trsfm_params in history:
        # No need to add the RandomDownsample since its Resampling operation
        # is taken into account in the history
        if trsfm_name == 'RandomDownsample':
            continue
        seed, attributes = _split_history_params(trsfm_params)
        seed_list.append(seed)
        trsfm_list.append(_build_transform(trsfm_name, attributes))
    return trsfm_list, seed_list
154