Passed
Pull Request — master (#353)
by Fernando
01:16
created

Gamma.apply_transform()   A

Complexity

Conditions 5

Size

Total Lines 16
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 15
nop 2
dl 0
loc 16
rs 9.1832
c 0
b 0
f 0
1
import warnings
2
from collections import defaultdict
3
from typing import Tuple, Optional, List
4
5
import torch
6
7
from ....utils import to_tuple
8
from ....torchio import DATA, TypeRangeFloat
9
from ....data.subject import Subject
10
from ... import IntensityTransform
11
from .. import RandomTransform
12
13
14
class RandomGamma(RandomTransform, IntensityTransform):
    r"""Randomly change contrast of an image by raising its values to the power
    :math:`\gamma`.

    Args:
        log_gamma: Tuple :math:`(a, b)` to compute the exponent
            :math:`\gamma = e ^ \beta`,
            where :math:`\beta \sim \mathcal{U}(a, b)`.
            If a single value :math:`d` is provided, then
            :math:`\beta \sim \mathcal{U}(-d, d)`.
            Negative and positive values for this argument perform gamma
            compression and expansion, respectively.
            See the `Gamma correction`_ Wikipedia entry for more information.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. _Gamma correction: https://en.wikipedia.org/wiki/Gamma_correction

    .. note:: An independent :math:`\gamma` is sampled for every image and
        for every slice along the first dimension of its data tensor.

    .. warning:: Fractional exponentiation of negative values is generally not
        well-defined for non-complex numbers.
        If negative values are found in the input image :math:`I`,
        the applied transform is :math:`\text{sign}(I) |I|^\gamma`,
        instead of the usual :math:`I^\gamma`. The
        :py:class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
        transform may be used to ensure that all values are positive.

    Example:
        >>> import torchio as tio
        >>> from torchio.datasets import FPG
        >>> subject = FPG()
        >>> transform = tio.RandomGamma(log_gamma=(-0.3, 0.3))  # gamma between 0.74 and 1.35
        >>> transformed = transform(subject)
    """
    def __init__(
            self,
            log_gamma: TypeRangeFloat = (-0.3, 0.3),
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.log_gamma_range = self.parse_range(log_gamma, 'log_gamma')

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample exponents and apply them through a :class:`Gamma` transform.

        Args:
            subject: Subject whose selected intensity images are transformed.

        Returns:
            The subject produced by the delegate :class:`Gamma` transform, or
            ``subject`` itself when no image is selected.
        """
        # Map each selected image name to one exponent per slice along the
        # first dimension of its data.
        gammas_per_image = {
            name: [self.get_params(self.log_gamma_range) for _ in image.data]
            for name, image in self.get_images_dict(subject).items()
        }
        if not gammas_per_image:
            # Gamma requires a ``gamma`` argument; with no selected images
            # there is nothing to transform.
            return subject
        # Gamma looks up ``self.gamma[name]`` for every image it selects, so
        # restrict it to exactly the images that received exponents here;
        # otherwise a different selection would raise KeyError.
        transform = Gamma(
            gamma=gammas_per_image,
            keys=list(gammas_per_image),
        )
        return transform(subject)

    @staticmethod
    def get_params(log_gamma_range: Tuple[float, float]) -> float:
        r"""Sample :math:`\gamma = e^\beta` with :math:`\beta \sim \mathcal{U}(a, b)`.

        Args:
            log_gamma_range: Bounds :math:`(a, b)` of the uniform distribution
                from which :math:`\beta` is drawn.

        Returns:
            The sampled exponent :math:`\gamma` as a Python float.
        """
        gamma = torch.FloatTensor(1).uniform_(*log_gamma_range).exp().item()
        return gamma
70
71
72
class Gamma(IntensityTransform):
    r"""Change contrast of an image by raising its values to the power
    :math:`\gamma`.

    Args:
        gamma: Exponent to which values in the image will be raised.
            Values below 1 perform gamma compression and values above 1
            perform gamma expansion.
            A sequence with one exponent per slice along the first dimension
            of the image data may also be given, or a dict mapping image names
            to such values (as built by :class:`RandomGamma`).
            See the `Gamma correction`_ Wikipedia entry for more information.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If the number of exponents given for an image does not
            match the length of the first dimension of its data.

    .. _Gamma correction: https://en.wikipedia.org/wiki/Gamma_correction

    .. warning:: Fractional exponentiation of negative values is generally not
        well-defined for non-complex numbers.
        If negative values are found in the input image :math:`I`,
        the applied transform is :math:`\text{sign}(I) |I|^\gamma`,
        instead of the usual :math:`I^\gamma`. The
        :py:class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
        transform may be used to ensure that all values are positive.

    .. note:: The inverse transform raises values to the power
        :math:`1 / \gamma` with the same sign handling, which undoes the
        forward transform (up to floating-point error) for nonzero
        :math:`\gamma`.

    Example:
        >>> import torchio as tio
        >>> from torchio.datasets import FPG
        >>> subject = FPG()
        >>> transform = tio.Gamma(0.8)
        >>> transformed = transform(subject)
    """
    def __init__(
            self,
            gamma: float,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.gamma = gamma
        # Names of the arguments that may be given per image name
        # (presumably consulted by ``arguments_are_dict()``).
        self.args_names = ('gamma',)
        # When True, values are raised to ``1 / gamma`` instead of ``gamma``.
        self.invert_transform = False

    def apply_transform(self, subject: Subject) -> Subject:
        """Raise each selected image to its exponent(s), slice by slice.

        Args:
            subject: Subject whose selected intensity images are modified.

        Returns:
            ``subject``, with the data of its selected images replaced.

        Raises:
            ValueError: If an image receives a number of exponents different
                from the length of the first dimension of its data.
        """
        for name, image in self.get_images_dict(subject).items():
            # Read from self.gamma afresh for every image; the per-slice loop
            # below uses its own name so it cannot overwrite this value.
            image_gamma = self.gamma
            if self.arguments_are_dict():
                image_gamma = self.gamma[name]
            num_slices = len(image.data)
            gammas = to_tuple(image_gamma, length=num_slices)
            if len(gammas) != num_slices:
                # zip() below would otherwise silently drop unmatched slices.
                message = (
                    f'Expected {num_slices} gamma values for image "{name}",'
                    f' but got {len(gammas)}'
                )
                raise ValueError(message)
            transformed_tensors = []
            for slice_gamma, tensor in zip(gammas, image.data):
                # The inverse of x ** g is x ** (1 / g); power() uses the same
                # sign handling in both directions, so the round trip holds
                # for negative values too.
                if self.invert_transform:
                    exponent = 1 / slice_gamma
                else:
                    exponent = slice_gamma
                transformed_tensors.append(power(tensor, exponent))
            image[DATA] = torch.stack(transformed_tensors)
        return subject
127
128
129
def power(tensor: torch.Tensor, gamma: float) -> torch.Tensor:
    """Raise ``tensor`` element-wise to ``gamma``, keeping the sign of negatives.

    Fractional powers of negative numbers are not real-valued, so if
    ``tensor`` contains any negative value a warning is emitted and
    ``sign(tensor) * |tensor| ** gamma`` is returned instead of
    ``tensor ** gamma``.

    Args:
        tensor: Input tensor (may be empty).
        gamma: Exponent.

    Returns:
        The element-wise result as a new tensor.
    """
    # ``min()`` raises on an empty tensor, and an empty tensor has no
    # negative values anyway.
    if tensor.numel() > 0 and tensor.min() < 0:
        message = (
            'Negative values found in input tensor. See the documentation for'
            ' more details on the implemented workaround:'
            ' https://torchio.readthedocs.io/transforms/augmentation.html#randomgamma'
        )
        # stacklevel=2 attributes the warning to the caller of power().
        warnings.warn(message, stacklevel=2)
        output = tensor.sign() * tensor.abs() ** gamma
    else:
        output = tensor ** gamma
    return output
141