Passed
Pull Request — master (#353)
by Fernando
01:16
created

RandomGamma.get_params()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from collections import defaultdict
3
from typing import Tuple, Optional, Sequence
4
5
import torch
6
7
from ....utils import to_tuple
8
from ....torchio import DATA, TypeRangeFloat
9
from ....data.subject import Subject
10
from ... import IntensityTransform
11
from .. import RandomTransform
12
13
14
class RandomGamma(RandomTransform, IntensityTransform):
    r"""Randomly change contrast of an image by raising its values to the power
    :math:`\gamma`.

    Args:
        log_gamma: Tuple :math:`(a, b)` to compute the exponent
            :math:`\gamma = e ^ \beta`,
            where :math:`\beta \sim \mathcal{U}(a, b)`.
            If a single value :math:`d` is provided, then
            :math:`\beta \sim \mathcal{U}(-d, d)`.
            Negative and positive values for this argument perform gamma
            compression and expansion, respectively.
            See the `Gamma correction`_ Wikipedia entry for more information.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. _Gamma correction: https://en.wikipedia.org/wiki/Gamma_correction

    .. warning:: Fractional exponentiation of negative values is generally not
        well-defined for non-complex numbers.
        If negative values are found in the input image :math:`I`,
        the applied transform is :math:`\text{sign}(I) |I|^\gamma`,
        instead of the usual :math:`I^\gamma`. The
        :py:class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
        transform may be used to ensure that all values are positive.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.FPG()
        >>> transform = tio.RandomGamma(log_gamma=(-0.3, 0.3))  # gamma between 0.74 and 1.34
        >>> transformed = transform(subject)
    """
    def __init__(
            self,
            log_gamma: TypeRangeFloat = (-0.3, 0.3),
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.log_gamma_range = self.parse_range(log_gamma, 'log_gamma')

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample one exponent per channel of each selected image and apply it.

        A fresh :math:`\\gamma` is drawn for every channel of every image
        returned by ``get_images_dict``. The sampled values are forwarded to a
        deterministic :class:`Gamma` transform that is restricted to exactly
        those images.

        Args:
            subject: Subject whose intensity images will be transformed.

        Returns:
            The subject produced by the :class:`Gamma` transform, or the input
            subject unchanged when no image was selected.
        """
        arguments = defaultdict(dict)
        for name, image in self.get_images_dict(subject).items():
            gammas = [self.get_params(self.log_gamma_range) for _ in image.data]
            arguments['gamma'][name] = gammas
        if not arguments:
            # No image was selected: Gamma() requires a ``gamma`` argument and
            # would raise TypeError, so there is nothing to do.
            return subject
        # Restrict Gamma to the images sampled above. With keys=None it could
        # visit images filtered out here and fail on ``self.gamma[name]``.
        sampled_names = list(arguments['gamma'])
        transform = Gamma(keys=sampled_names, **arguments)
        transformed = transform(subject)
        return transformed

    def get_params(self, log_gamma_range: Tuple[float, float]) -> float:
        """Return :math:`e^\\beta` with :math:`\\beta` sampled uniformly in the range."""
        gamma = self.sample_uniform(*log_gamma_range).exp().item()
        return gamma
69
70
71
class Gamma(IntensityTransform):
    r"""Change contrast of an image by raising its values to the power
    :math:`\gamma`.

    Args:
        gamma: Exponent to which values in the image will be raised.
            Values lower and greater than 1 perform gamma compression and
            expansion, respectively.
            See the `Gamma correction`_ Wikipedia entry for more information.
            It may also be a dictionary mapping image names to exponents
            (as built by :class:`RandomGamma`); each value is expanded to one
            exponent per channel with ``to_tuple``.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. _Gamma correction: https://en.wikipedia.org/wiki/Gamma_correction

    .. warning:: Fractional exponentiation of negative values is generally not
        well-defined for non-complex numbers.
        If negative values are found in the input image :math:`I`,
        the applied transform is :math:`\text{sign}(I) |I|^\gamma`,
        instead of the usual :math:`I^\gamma`. The
        :py:class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
        transform may be used to ensure that all values are positive.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.FPG()
        >>> transform = tio.Gamma(0.8)
        >>> transformed = transform(subject)
    """
    def __init__(
            self,
            gamma: float,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(keys=keys)
        self.gamma = gamma
        self.args_names = ('gamma',)
        # When True, apply_transform undoes a previously applied gamma.
        self.invert_transform = False

    def apply_transform(self, subject: Subject) -> Subject:
        """Raise each channel of each selected image to its exponent.

        Args:
            subject: Subject whose intensity images will be transformed.

        Returns:
            The same subject, with image data replaced in place.
        """
        for name, image in self.get_images_dict(subject).items():
            # Resolve the exponent for this image from self.gamma on every
            # iteration; it must not be overwritten by per-channel values of a
            # previously processed image.
            if self.arguments_are_dict():
                image_gamma = self.gamma[name]
            else:
                image_gamma = self.gamma
            channel_gammas = to_tuple(image_gamma, length=len(image.data))
            transformed_tensors = []
            for channel_gamma, tensor in zip(channel_gammas, image.data):
                if self.invert_transform:
                    # The exact inverse of x -> sign(x)|x|^g is
                    # y -> sign(y)|y|^(1/g).
                    transformed_tensor = power(tensor, 1 / channel_gamma)
                else:
                    transformed_tensor = power(tensor, channel_gamma)
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        return subject
126
127
128
def power(tensor, gamma):
    """Raise ``tensor`` element-wise to the power ``gamma``.

    If any element is negative, a warning is issued and
    ``sign(tensor) * |tensor| ** gamma`` is returned instead, because
    fractional powers of negative real numbers are not well-defined.

    Args:
        tensor: Input tensor.
        gamma: Exponent.

    Returns:
        A new tensor with the exponent applied.
    """
    # ``(tensor < 0).any()`` is False for empty tensors (``tensor.min()``
    # raises on them) and, unlike ``min()``, is not masked by NaN values.
    if (tensor < 0).any():
        message = (
            'Negative values found in input tensor. See the documentation for'
            ' more details on the implemented workaround:'
            ' https://torchio.readthedocs.io/transforms/augmentation.html#randomgamma'
        )
        warnings.warn(message, stacklevel=2)
        return tensor.sign() * tensor.abs() ** gamma
    return tensor ** gamma
140