Passed
Pull Request — master (#353)
by Fernando
01:03
created

RandomNoise.__init__()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 9
nop 5
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
import copy
2
from collections import defaultdict
3
from typing import Tuple, Optional, Union, List, Dict, Sequence
4
5
import torch
6
from ....torchio import DATA
7
from ....utils import get_random_seed
8
from ....data.subject import Subject
9
from ... import IntensityTransform
10
from .. import RandomTransform
11
12
13
class RandomNoise(RandomTransform, IntensityTransform):
    r"""Add Gaussian noise with random parameters.

    Add noise sampled from a normal distribution with random parameters.
    The mean, standard deviation and random seed are sampled independently
    for every image the transform is applied to.

    Args:
        mean: Mean :math:`\mu` of the Gaussian distribution
            from which the noise is sampled.
            If two values :math:`(a, b)` are provided,
            then :math:`\mu \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\mu \sim \mathcal{U}(-d, d)`.
        std: Standard deviation :math:`\sigma` of the Gaussian distribution
            from which the noise is sampled.
            If two values :math:`(a, b)` are provided,
            then :math:`\sigma \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\sigma \sim \mathcal{U}(0, d)`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            mean: Union[float, Tuple[float, float]] = 0,
            std: Union[float, Tuple[float, float]] = (0, 0.25),
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.mean_range = self.parse_range(mean, 'mean')
        self.std_range = self.parse_range(std, 'std', min_constraint=0)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample noise parameters per image and apply a :class:`Noise`.

        Args:
            subject: Subject whose images will receive additive noise.

        Returns:
            The result of calling the deterministic :class:`Noise` transform
            on ``subject``.
        """
        # Explicit dicts (rather than a defaultdict that is only populated
        # inside the loop) guarantee that all three arguments are passed to
        # Noise even when no image is selected, in which case Noise simply
        # has nothing to iterate over instead of failing with a TypeError.
        means: Dict[str, float] = {}
        stds: Dict[str, float] = {}
        seeds: Dict[str, int] = {}
        for image_name in self.get_images_dict(subject):
            mean, std, seed = self.get_params(self.mean_range, self.std_range)
            means[image_name] = mean
            stds[image_name] = std
            seeds[image_name] = seed
        # NOTE(review): `keys` is not forwarded to Noise here; confirm that
        # Noise.get_images_dict selects the same images as this transform.
        transform = Noise(mean=means, std=stds, seed=seeds)
        return transform(subject)

    @staticmethod
    def get_params(
            mean_range: Tuple[float, float],
            std_range: Tuple[float, float],
            ) -> Tuple[float, float, int]:
        """Sample the parameters for one image.

        Args:
            mean_range: ``(low, high)`` bounds of the uniform distribution
                from which the noise mean is drawn.
            std_range: ``(low, high)`` bounds of the uniform distribution
                from which the noise standard deviation is drawn.

        Returns:
            Tuple ``(mean, std, seed)`` where ``seed`` is the value returned
            by ``get_random_seed()``.
        """
        mean = torch.FloatTensor(1).uniform_(*mean_range).item()
        std = torch.FloatTensor(1).uniform_(*std_range).item()
        seed = get_random_seed()
        return mean, std, seed
66
67
68
class Noise(IntensityTransform):
    r"""Add Gaussian noise.

    Add noise sampled from a normal distribution.

    ``mean``, ``std`` and ``seed`` must either all be scalars (the same
    values are used for every image) or all be dictionaries mapping image
    names to the value used for that image.

    Args:
        mean: Mean :math:`\mu` of the Gaussian distribution
            from which the noise is sampled.
        std: Standard deviation :math:`\sigma` of the Gaussian distribution
            from which the noise is sampled.
        seed: Seed used to generate the noise, so that the same noise can be
            reproduced (and subtracted by the inverse transform).
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If some of ``mean``, ``std`` and ``seed`` are
            dictionaries and others are not.
    """
    def __init__(
            self,
            mean: Union[float, Dict[str, float]],
            std: Union[float, Dict[str, float]],
            seed: Union[int, Dict[str, int]],
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.mean = mean
        self.std = std
        self.seed = seed
        # Validates the arguments as a side effect (may raise ValueError).
        self._has_dicts = self.arguments_are_dict()
        # When True, the noise is subtracted instead of added.
        self.invert_transform = False
        # Names of the constructor arguments that define this transform.
        self.args_names = 'mean', 'std', 'seed'

    def arguments_are_dict(self) -> bool:
        """Tell whether the parameters were given per image.

        Returns:
            ``True`` if ``mean``, ``std`` and ``seed`` are all dictionaries,
            ``False`` if none of them is.

        Raises:
            ValueError: If only some of the three are dictionaries.
        """
        mean_dict = isinstance(self.mean, dict)
        std_dict = isinstance(self.std, dict)
        seed_dict = isinstance(self.seed, dict)
        three_bools = mean_dict, std_dict, seed_dict
        if all(three_bools):
            return True
        elif not any(three_bools):
            return False
        else:
            message = 'All arguments must have the same type: float or dict'
            raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Add (or, when inverted, subtract) seeded noise to each image.

        Args:
            subject: Subject whose images are modified in place.

        Returns:
            The same ``subject`` instance, with updated image data.
        """
        for name, image in self.get_images_dict(subject).items():
            if self._has_dicts:
                # Per-image parameters, looked up by image name.
                mean, std, seed = self.mean[name], self.std[name], self.seed[name]
            else:
                mean, std, seed = self.mean, self.std, self.seed
            noise = get_noise(image[DATA], mean, std, seed)
            if self.invert_transform:
                # Same seed gives the same noise, so negating it undoes a
                # previous application.
                noise *= -1
            image[DATA] = image[DATA] + noise
        return subject

    def inverse(self):
        """Return a deep copy of this transform with the noise sign flipped."""
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new
124
125
126
def get_randn(shape: Sequence[int], seed: int) -> torch.Tensor:
    """Return standard normal samples that depend only on ``seed``.

    A private :class:`torch.Generator` is seeded and used for sampling, so
    the global random number generators (CPU and CUDA) are never modified,
    even if sampling raises. On the CPU this yields the same values as
    calling ``torch.manual_seed(seed)`` followed by ``torch.randn(*shape)``.

    Args:
        shape: Shape of the tensor to generate.
        seed: Seed for the private generator.

    Returns:
        Tensor of the requested shape sampled from :math:`\\mathcal{N}(0, 1)`.
    """
    generator = torch.Generator()
    # int() mirrors the conversion torch.manual_seed applies to its argument.
    generator.manual_seed(int(seed))
    # Passing the shape as one tuple also supports an empty shape.
    return torch.randn(tuple(shape), generator=generator)
132
133
134
def get_noise(
        tensor: torch.Tensor,
        mean: float,
        std: float,
        seed: Optional[int] = None) -> torch.Tensor:
    """Generate Gaussian noise with the same shape as ``tensor``.

    Args:
        tensor: Tensor whose shape determines the shape of the noise.
        mean: Mean of the Gaussian distribution.
        std: Standard deviation of the Gaussian distribution.
        seed: Seed making the noise reproducible. If ``None``, the noise is
            drawn from the global random number generator instead.

    Returns:
        Noise tensor computed as ``standard_normal * std + mean``.
    """
    if seed is None:
        # The declared default used to crash when seeding; fall back to
        # unseeded sampling so that the default is actually usable.
        standard_noise = torch.randn(tensor.shape)
    else:
        standard_noise = get_randn(tensor.shape, seed=seed)
    return standard_noise * std + mean
140