Passed
Push — master ( bbbeef...45400a )
by Fernando
01:13
created

RandomGamma.__init__()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 8
nop 5
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
from typing import Tuple, Optional, List
2
import torch
3
from ....torchio import DATA, TypeRangeFloat
4
from ....data.subject import Subject
5
from .. import RandomTransform
6
7
8
class RandomGamma(RandomTransform):
    r"""Change contrast of an image by raising its values to the power
    :math:`\gamma`.

    Args:
        log_gamma: Tuple :math:`(a, b)` to compute the exponent
            :math:`\gamma = e ^ \beta`,
            where :math:`\beta \sim \mathcal{U}(a, b)`.
            If a single value :math:`d` is provided, then
            :math:`\beta \sim \mathcal{U}(-d, d)`.
            Negative and positive values for this argument perform gamma
            compression and expansion, respectively.
            See the `Gamma correction`_ Wikipedia entry for more information.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: Raising a negative number to a fractional power is not defined
        for real numbers (PyTorch returns ``NaN``). If an image contains
        negative values, the transform applied to it is
        :math:`\operatorname{sign}(I) \, |I|^\gamma` instead of
        :math:`I^\gamma`, so the sign of every voxel is preserved.

    .. _Gamma correction: https://en.wikipedia.org/wiki/Gamma_correction

    Example:
        >>> import torchio
        >>> from torchio import RandomGamma
        >>> from torchio.datasets import FPG
        >>> sample = FPG()
        >>> transform = RandomGamma(log_gamma=(-0.3, 0.3))  # gamma between 0.74 and 1.35
        >>> transformed = transform(sample)
    """
    def __init__(
            self,
            log_gamma: TypeRangeFloat = (-0.3, 0.3),
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        self.log_gamma_range = self.parse_range(log_gamma, 'log_gamma')

    def apply_transform(self, sample: Subject) -> Subject:
        """Raise each image in ``sample`` to its own random power ``gamma``.

        A new ``gamma`` is drawn for every entry of
        ``sample.get_images_dict()``. That entry's ``DATA`` tensor is then
        replaced by the transformed tensor. The sampled values are recorded
        with ``sample.add_transform``.

        Args:
            sample: Subject whose images are transformed.

        Returns:
            The same ``sample`` object, with its image data replaced.
        """
        random_parameters_images_dict = {}
        for image_name, image_dict in sample.get_images_dict().items():
            gamma = self.get_params(self.log_gamma_range)
            random_parameters_images_dict[image_name] = {'gamma': gamma}
            image_dict[DATA] = self._apply_gamma(image_dict[DATA], gamma)
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def _apply_gamma(
            tensor: torch.Tensor,
            gamma: torch.Tensor,
            ) -> torch.Tensor:
        """Return ``tensor ** gamma``, keeping the sign of negative values."""
        # ``(tensor < 0).any()`` is safe on empty tensors, unlike ``.min()``.
        if (tensor < 0).any():
            # A negative base with a fractional exponent yields NaN.
            # Exponentiate the magnitude instead and restore the sign.
            return tensor.sign() * tensor.abs() ** gamma
        return tensor ** gamma

    @staticmethod
    def get_params(log_gamma_range: Tuple[float, float]) -> torch.Tensor:
        """Sample ``gamma = exp(beta)`` with ``beta ~ U(*log_gamma_range)``.

        Args:
            log_gamma_range: Lower and upper bounds of the uniform
                distribution that ``beta`` (the log of ``gamma``) is drawn
                from.

        Returns:
            A one-element float tensor holding the sampled ``gamma``.
        """
        gamma = torch.FloatTensor(1).uniform_(*log_gamma_range).exp()
        return gamma
59