Passed
Push — master ( db49f2...b2e640 )
by Fernando
01:27
created

Compose.__init__()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
from typing import Union, Sequence
2
3
import torch
0 ignored issues
show
introduced by
Unable to import 'torch'
Loading history...
4
import numpy as np
0 ignored issues
show
introduced by
Unable to import 'numpy'
Loading history...
5
from torchvision.transforms import Compose as PyTorchCompose
0 ignored issues
show
introduced by
Unable to import 'torchvision.transforms'
Loading history...
6
7
from .. import Transform
8
from . import RandomTransform
9
10
11
class Compose(Transform):
    """Chain several transforms into a single transform.

    Args:
        transforms: list of instances of
            :py:class:`~torchio.transforms.transform.Transform`.
        p: Probability that this transform will be applied.

    .. note::
        This is a thin wrapper of :py:class:`torchvision.transforms.Compose`.
    """
    def __init__(self, transforms: Sequence[Transform], p: float = 1):
        super().__init__(p=p)
        # All the chaining work is delegated to torchvision's Compose.
        composed = PyTorchCompose(transforms)
        self.transform = composed

    def apply_transform(self, sample: dict):
        """Run the wrapped torchvision ``Compose`` on ``sample``.

        Returns whatever the composed callable returns.
        """
        composed = self.transform
        return composed(sample)
28
29
30
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :py:class:`~torchio.transforms.transform.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        p: Probability that this transform will be applied.

    Example:
        >>> import torchio
        >>> ixi = torchio.datasets.ixi.IXITiny('ixi', download=True)
        >>> sample = ixi[0]
        >>> transforms_dict = {
        ...     torchio.transforms.RandomAffine(): 0.75,
        ...     torchio.transforms.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = torchio.transforms.OneOf(transforms_dict)

    """
    def __init__(self, transforms: Union[dict, Sequence], p: float = 1):
        super().__init__(p=p)
        # Maps each Transform instance to its selection weight.
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, sample: dict):
        """Draw one transform according to its weight and apply it to ``sample``.

        Returns the output of the chosen transform.
        """
        weights = torch.Tensor(list(self.transforms_dict.values()))
        # multinomial returns a 1-element LongTensor; convert it to a plain
        # Python int before using it as a list index.
        index = torch.multinomial(weights, 1).item()
        transforms = list(self.transforms_dict.keys())
        transform = transforms[index]
        return transform(sample)

    def _get_transforms_dict(self, transforms: Union[dict, Sequence]):
        """Build the ``{transform: weight}`` mapping from the user argument.

        Args:
            transforms: Either a dict of transforms to (unnormalized)
                probabilities, or a sequence of transforms that will all get
                the same probability.

        Returns:
            A new dict; the caller's dict is never modified.

        Raises:
            ValueError: If ``transforms`` is neither a dict nor a sized
                sequence, if it is empty, if any key is not a
                :py:class:`Transform`, or if the probabilities are invalid.
        """
        if isinstance(transforms, dict):
            # Copy first: normalization below rewrites values in place.
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as error:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from error
            if num_transforms == 0:
                # Previously this surfaced as a bare ZeroDivisionError.
                message = (
                    'Transforms argument must contain at least one transform'
                )
                raise ValueError(message)
            probability = 1 / num_transforms
            # Note: repeated instances collapse into a single dict key.
            transforms_dict = {
                transform: probability for transform in transforms
            }
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(transforms_dict: dict):
        """Rescale the values of ``transforms_dict`` in place so they sum to one.

        Raises:
            ValueError: If any probability is negative or not finite, or if
                all probabilities are zero.
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if not np.all(np.isfinite(probabilities)):
            # NaN/inf would otherwise pass the checks here and only fail
            # later, inside torch.multinomial.
            message = (
                'Probabilities must be finite numbers,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        total = probabilities.sum()  # loop-invariant, compute once
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / total
104