Passed
Pull Request — master (#353)
by Fernando
01:09
created

Noise.get_arguments()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import copy
2
from collections import defaultdict
3
from typing import Tuple, Optional, Union, List, Dict, Sequence
4
5
import torch
6
from ....torchio import DATA
7
from ....utils import get_random_seed
8
from ....data.subject import Subject
9
from ... import IntensityTransform
10
from .. import RandomTransform
11
12
13
class RandomNoise(RandomTransform, IntensityTransform):
    r"""Add Gaussian noise with random parameters.

    Add noise sampled from a normal distribution with random parameters.
    A separate mean, standard deviation and seed are drawn for every image
    the transform is applied to.

    Args:
        mean: Mean :math:`\mu` of the Gaussian distribution
            from which the noise is sampled.
            If two values :math:`(a, b)` are provided,
            then :math:`\mu \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\mu \sim \mathcal{U}(-d, d)`.
        std: Standard deviation :math:`\sigma` of the Gaussian distribution
            from which the noise is sampled.
            If two values :math:`(a, b)` are provided,
            then :math:`\sigma \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\sigma \sim \mathcal{U}(0, d)`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            mean: Union[float, Tuple[float, float]] = 0,
            std: Union[float, Tuple[float, float]] = (0, 0.25),
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.mean_range = self.parse_range(mean, 'mean')
        self.std_range = self.parse_range(std, 'std', min_constraint=0)

    def apply_transform(self, subject: Subject) -> Subject:
        # Draw independent (mean, std, seed) for each selected image and
        # collect them as {'mean': {name: value}, 'std': {...}, 'seed': {...}},
        # the per-image dict form accepted by Noise.
        arguments = defaultdict(dict)
        for image_name in self.get_images_dict(subject):
            mean, std, seed = self.get_params(self.mean_range, self.std_range)
            arguments['mean'][image_name] = mean
            arguments['std'][image_name] = std
            arguments['seed'][image_name] = seed
        if not arguments:
            # No image selected: Noise requires mean/std/seed, so building it
            # from an empty dict would raise TypeError. Nothing to modify.
            return subject
        # Restrict Noise to exactly the images we drew parameters for, so it
        # never indexes its per-image dicts with a missing name.
        # NOTE(review): assumes get_images_dict honors ``keys`` — confirm.
        image_names = list(arguments['mean'])
        transform = Noise(keys=image_names, **arguments)
        transformed = transform(subject)
        return transformed

    @staticmethod
    def get_params(
            mean_range: Tuple[float, float],
            std_range: Tuple[float, float],
            ) -> Tuple[float, float, int]:
        """Sample a mean, a standard deviation and a seed.

        Args:
            mean_range: ``(low, high)`` bounds of the uniform mean sample.
            std_range: ``(low, high)`` bounds of the uniform std sample.

        Returns:
            ``(mean, std, seed)``. The seed comes from ``get_random_seed()``.
        """
        mean = torch.FloatTensor(1).uniform_(*mean_range).item()
        std = torch.FloatTensor(1).uniform_(*std_range).item()
        seed = get_random_seed()
        return mean, std, seed
66
67
68
class Noise(IntensityTransform):
    r"""Add Gaussian noise.

    Add noise sampled from a normal distribution.

    Args:
        mean: Mean :math:`\mu` of the Gaussian distribution
            from which the noise is sampled.
        std: Standard deviation :math:`\sigma` of the Gaussian distribution
            from which the noise is sampled.
        seed: Seed used to sample the noise. Reusing the same seed
            regenerates the same noise, which is what allows the inverse
            transform to subtract it again.
        keys: See :py:class:`~torchio.transforms.Transform`.

    ``mean``, ``std`` and ``seed`` must either all be scalars, used for
    every image, or all be dictionaries mapping image names to per-image
    values. Mixing the two forms raises :class:`ValueError`.
    """
    def __init__(
            self,
            mean: Union[float, Dict[str, float]],
            std: Union[float, Dict[str, float]],
            seed: Union[int, Dict[str, int]],
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.mean = mean
        self.std = std
        self.seed = seed
        # True when the parameters are per-image dicts, False when scalars.
        self._has_dicts = self.parse_arguments()
        # Toggled by inverse(): the regenerated noise is subtracted instead.
        self.invert_transform = False

    def get_arguments(self):
        """Return the ``mean``, ``std`` and ``seed`` of this transform."""
        return {'mean': self.mean, 'std': self.std, 'seed': self.seed}

    def parse_arguments(self):
        """Check that the parameters are all dicts or all non-dicts.

        Returns:
            ``True`` if all three are dicts, ``False`` if none are.

        Raises:
            ValueError: If only some of them are dicts.
        """
        are_dicts = [
            isinstance(value, dict)
            for value in (self.mean, self.std, self.seed)
        ]
        if all(are_dicts):
            return True
        if not any(are_dicts):
            return False
        message = 'All arguments must have the same type: float or dict'
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        args = self.mean, self.std, self.seed
        for name, image in self.get_images_dict(subject).items():
            if self._has_dicts:
                # Per-image parameters; a name missing from the dicts raises
                # KeyError.
                args = self.mean[name], self.std[name], self.seed[name]
            noise = get_noise(image[DATA], *args)
            if self.invert_transform:
                # Same seed -> same noise, so negating it removes what the
                # forward pass added (up to floating-point rounding).
                noise *= -1
            image[DATA] = image[DATA] + noise
        return subject

    def inverse(self):
        """Return a copy of this transform with the noise sign flipped."""
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new
126
127
128
def get_randn(
        shape: Sequence[int],
        seed: Optional[int] = None,
        ) -> torch.Tensor:
    """Sample a tensor of standard normal values.

    Args:
        shape: Shape of the returned tensor.
        seed: If given, values are drawn from a private CPU generator seeded
            with it. The same seed therefore always yields the same tensor,
            and the global RNG state is left untouched. If ``None``, the
            global default generator is used.

    Returns:
        A CPU tensor of shape ``shape``.
    """
    generator = None
    if seed is not None:
        # A local generator avoids torch.manual_seed, which reseeds the
        # generators of all devices (CUDA included) and would otherwise
        # require saving and restoring global state.
        generator = torch.Generator().manual_seed(seed)
    # Pass the size as a tuple so that 0-d shapes also work.
    return torch.randn(tuple(shape), generator=generator)


def get_noise(
        tensor: torch.Tensor,
        mean: float,
        std: float,
        seed: Optional[int] = None) -> torch.Tensor:
    """Return Gaussian noise with the same shape and device as ``tensor``.

    Args:
        tensor: Tensor whose shape and device the noise should match.
        mean: Mean of the noise.
        std: Standard deviation of the noise.
        seed: Seed for reproducible noise. ``None`` uses the global RNG.

    Returns:
        ``randn(tensor.shape) * std + mean``, placed on ``tensor.device``.
    """
    noise = get_randn(tensor.shape, seed=seed).to(tensor.device)
    return noise * std + mean
142