Passed
Pull Request — master (#287)
by Fernando
01:37
created

Compose.__call__()   A

Complexity

Conditions 3

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 3
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
from typing import Union, Sequence, List
2
3
import torch
4
import torchio
5
import numpy as np
6
from torchvision.transforms import Compose as PyTorchCompose
7
8
from ...data.subject import Subject
9
from ...utils import gen_seed
10
from .. import Transform
11
from . import RandomTransform
12
13
14
class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(self, transforms: Sequence[Transform] = (), p: float = 1):
        super().__init__(p=p)
        # Copy into a fresh list: a mutable default (or the caller's own
        # list) must never end up shared between Compose instances.
        self.transform = PyTorchCompose(list(transforms))

    def __call__(
            self,
            data: Union[Subject, torch.Tensor, np.ndarray],
            seeds: Union[List, None] = None,
            ):
        """Apply the composed transforms to ``data``.

        Args:
            data: Input forwarded to the parent ``__call__``.
            seeds: One seed per composed transform. If ``None`` or empty,
                a new seed is generated for each transform with
                ``gen_seed()``.

        Returns:
            ``data`` itself when there are no transforms to apply, otherwise
            whatever the parent ``__call__`` returns.
        """
        if not self.transform.transforms:
            return data

        if not seeds:
            seeds = [gen_seed() for _ in self.transform.transforms]
        # Kept on the instance because apply_transform() receives only the
        # sample and reads the seeds from here.
        self.seeds = seeds
        return super().__call__(data, seeds)

    def apply_transform(self, sample: Subject):
        """Run every composed transform, in order, with its paired seed.

        Relies on ``self.seeds``, which is set by ``__call__``.
        """
        # NOTE: zip() stops at the shorter input, so transforms without a
        # matching seed are not applied.
        for t, s in zip(self.transform.transforms, self.seeds):
            sample = t(sample, s)
        return sample
42
43
44
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.transforms.OneOf(transforms_dict)
        >>> transformed = transform(colin)

    """
    def __init__(
            self,
            transforms: Union[dict, Sequence[Transform]],
            p: float = 1,
            ):
        super().__init__(p=p)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject):
        """Sample one transform according to its weight and apply it."""
        weights = torch.Tensor(list(self.transforms_dict.values()))
        # multinomial() returns a 1-element tensor; convert it to a plain
        # int before using it as a list index.
        index = torch.multinomial(weights, 1).item()
        transforms = list(self.transforms_dict)
        transform = transforms[index]
        return transform(subject)

    def _get_transforms_dict(self, transforms: Union[dict, Sequence]):
        """Build the ``{transform: probability}`` mapping used for sampling.

        Args:
            transforms: Dictionary mapping transforms to probabilities, or a
                sized collection of transforms that all get the same
                probability.

        Returns:
            A new dictionary; the argument is not modified.

        Raises:
            ValueError: If ``transforms`` is neither a dictionary nor a sized
                collection, if it is empty, if the probabilities are
                invalid, or if any key is not a ``Transform`` instance.
        """
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            if num_transforms == 0:
                # Same exception type an empty dictionary produces, instead
                # of a bare ZeroDivisionError from 1 / 0.
                raise ValueError('At least one transform must be given')
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(transforms_dict: dict):
        """Rescale, in place, the values of the dictionary so they sum to one.

        Raises:
            ValueError: If any value is negative or if all values are zero
                (which includes an empty dictionary).
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        total = probabilities.sum()  # loop-invariant, computed once
        # Only existing keys are reassigned, so the dictionary size does not
        # change while iterating over it.
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / total
122
123
124
def compose_from_history(history: List):
    """Build a composition of transforms from a given subject history.

    Args:
        history: Subject history given as a list of tuples containing
            ``(transformation_name, transformation_parameters)``. The
            parameters mapping must contain a ``'seed'`` entry.

    Returns:
        Tuple ``(compose, seeds)``: a :py:class:`Compose` of the rebuilt
        transforms and the list of seeds needed to reproduce the transforms
        from the history.
    """
    rebuilt_transforms = []
    seeds = []
    for entry in history:
        name = entry[0]
        parameters = entry[1]
        seeds.append(parameters['seed'])
        # Every recorded parameter except the seed becomes an attribute of
        # the rebuilt transform.
        attributes = {
            key: value
            for key, value in parameters.items()
            if key != "seed"
        }
        # Instantiate with default arguments, then replace the whole
        # instance state with the recorded attributes.
        transform_class = getattr(torchio, name)
        transform = transform_class()
        transform.__dict__ = attributes
        rebuilt_transforms.append(transform)
    return Compose(rebuilt_transforms), seeds
140