Passed
Pull Request — master (#287)
by Fernando
01:18
created

OneOf.__init__()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 3
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
from typing import Union, Sequence, List, Optional
2
3
import torch
4
import torchio
5
import numpy as np
6
from torchvision.transforms import Compose as PyTorchCompose
7
8
from ...data.subject import Subject
9
from ...utils import gen_seed
10
from .. import Transform
11
from . import RandomTransform
12
13
14
class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`. If ``None``,
            an empty composition is created and calling it returns the input
            unchanged.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(
            self,
            transforms: Optional[Sequence[Transform]] = None,
            p: float = 1,
            ):
        super().__init__(p=p)
        transforms = [] if transforms is None else transforms
        self.transform = PyTorchCompose(transforms)

    def __call__(
            self,
            data: Union[Subject, torch.Tensor, np.ndarray],
            seeds: Optional[Sequence] = None,
            ):
        """Apply the composed transforms to ``data``.

        Args:
            data: Input handed to the first transform; each transform's output
                is handed to the next one.
            seeds: One seed per composed transform, used in order. If ``None``
                or empty, a fresh seed is drawn for every transform with
                ``gen_seed()``. Extra trailing seeds are ignored.

        Returns:
            ``data`` unchanged if there are no transforms; otherwise the result
            of the parent ``__call__``.

        Raises:
            ValueError: If fewer seeds than transforms are given, so that no
                transform is silently skipped.
        """
        transforms = self.transform.transforms
        if not transforms:
            # Nothing to apply.
            return data

        if not seeds:
            seeds = [gen_seed() for _ in transforms]
        else:
            # Materialize so len() works and the seeds can be re-iterated.
            seeds = list(seeds)
            if len(seeds) < len(transforms):
                message = (
                    f'Expected at least {len(transforms)} seeds (one per'
                    f' transform), got {len(seeds)}'
                )
                raise ValueError(message)
        # Kept on the instance because apply_transform() reads it to pair each
        # transform with its seed (apply_transform is presumably invoked by
        # the parent __call__).
        self.seeds = seeds
        return super().__call__(data, seeds)

    def apply_transform(self, sample: Subject):
        """Apply every transform in order, each with its seed from ``self.seeds``.

        Args:
            sample: Input to the first transform.

        Returns:
            The output of the last transform.
        """
        for transform, seed in zip(self.transform.transforms, self.seeds):
            sample = transform(sample, seed)
        return sample
47
48
49
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Raises:
        ValueError: If ``transforms`` is neither a dictionary nor a sized
            sequence, is empty, contains a key that is not a ``Transform``,
            or (dictionary only) has negative, non-finite or all-zero
            probabilities.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.transforms.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: Union[dict, Sequence[Transform]],
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject):
        """Draw one transform according to its probability and apply it.

        The draw uses ``torch.multinomial`` without an explicit generator.
        The chosen transform is called with ``subject`` only; no seed is
        forwarded to it.
        """
        transforms = list(self.transforms_dict.keys())
        weights = torch.Tensor(list(self.transforms_dict.values()))
        # multinomial returns a one-element integer tensor; .item() turns it
        # into a plain Python int usable as a list index.
        index = torch.multinomial(weights, 1).item()
        transform = transforms[index]
        return transform(subject)

    def _get_transforms_dict(self, transforms: Union[dict, Sequence]):
        """Return a ``{transform: probability}`` dict built from ``transforms``.

        A dict is copied and its probabilities normalized. A sequence gets a
        uniform probability of ``1 / len(transforms)`` per element.

        Raises:
            ValueError: On an unsized or empty input, or a key that is not a
                ``Transform``.
        """
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            if num_transforms == 0:
                # Report an empty sequence as a ValueError rather than a
                # ZeroDivisionError from 1 / 0, consistent with an empty dict.
                message = 'Transforms sequence must not be empty'
                raise ValueError(message)
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(transforms_dict: dict):
        """Validate the probabilities in ``transforms_dict`` and scale them in place to sum to one.

        Raises:
            ValueError: If any probability is non-finite or negative, or if
                all of them are zero.
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if not np.all(np.isfinite(probabilities)):
            message = (
                'Probabilities must be finite,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        # Computed once instead of on every loop iteration.
        total = probabilities.sum()
        # Reassigning values of existing keys does not change the dict size,
        # so mutating while iterating items() is safe here.
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / total
127
128
129
def compose_from_history(history: List):
    """Build a composition of transforms from a subject history.

    Args:
        history: Subject history given as a list of tuples
            ``(transform_name, transform_parameters)``. ``transform_name`` is
            looked up as an attribute of the ``torchio`` module.
            ``transform_parameters`` is a dict of instance attributes that
            must include a ``'seed'`` entry.

    Returns:
        Tuple ``(compose, seeds)``. ``compose`` is a :class:`Compose` of the
        rebuilt transforms. ``seeds`` is the list of recorded seeds, in the
        same order, to pass to it to reproduce the transforms.

    Raises:
        KeyError: If a parameters dict has no ``'seed'`` entry.
    """
    transforms = []
    seeds = []
    for entry in history:
        transform_name, transform_params = entry[0], entry[1]
        seeds.append(transform_params['seed'])
        params_without_seed = {
            key: value
            for key, value in transform_params.items()
            if key != 'seed'
        }
        # Instantiate with default arguments, then overwrite the recorded
        # attributes. update() is used instead of replacing __dict__ so that
        # attributes set by __init__ but absent from the history are kept.
        transform = getattr(torchio, transform_name)()
        transform.__dict__.update(params_without_seed)
        transforms.append(transform)
    return Compose(transforms), seeds
145