Passed
Push — master ( ba751d...d677ae )
by Fernando
01:28
created

TestRandomGamma.test_negative_values()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import torch
2
from torchio import RandomGamma
3
from ...utils import TorchioTestCase
4
5
6
class TestRandomGamma(TorchioTestCase):
    """Unit tests for the `RandomGamma` intensity transform."""

    def test_with_zero_gamma(self):
        """With `log_gamma=0`, expect the T1 data to come back unchanged."""
        output = RandomGamma(log_gamma=0)(self.sample)
        self.assertTensorAlmostEqual(self.sample.t1.data, output.t1.data)

    def test_with_non_zero_gamma(self):
        """A strictly positive `log_gamma` range must alter the T1 data."""
        output = RandomGamma(log_gamma=(0.1, 0.3))(self.sample)
        self.assertTensorNotEqual(self.sample.t1.data, output.t1.data)

    def test_with_high_gamma(self):
        """A huge fixed `log_gamma` should keep only voxels equal to 1.

        NOTE(review): this presumably relies on the fixture's T1
        intensities lying in [0, 1]; confirm against `TorchioTestCase`.
        """
        output = RandomGamma(log_gamma=(100, 100))(self.sample)
        # Build the mask after transforming, matching the original order.
        ones_mask = self.sample.t1.data == 1
        self.assertTensorAlmostEqual(ones_mask, output.t1.data)

    def test_with_low_gamma(self):
        """A very negative fixed `log_gamma` should map positives to 1.

        The expected output is the mask of strictly positive voxels.
        NOTE(review): assumes non-negative fixture intensities; confirm.
        """
        output = RandomGamma(log_gamma=(-100, -100))(self.sample)
        # Build the mask after transforming, matching the original order.
        positive_mask = self.sample.t1.data > 0
        self.assertTensorAlmostEqual(positive_mask, output.t1.data)

    def test_wrong_gamma_type(self):
        """A non-numeric `log_gamma` must be rejected with `ValueError`."""
        with self.assertRaises(ValueError):
            RandomGamma(log_gamma='wrong')

    def test_negative_values(self):
        """Input containing negative values must emit a `UserWarning`."""
        # torch.rand samples [0, 1), so shifting by -1 makes every voxel < 0.
        negative_tensor = torch.rand(1, 3, 3, 3) - 1
        with self.assertWarns(UserWarning):
            RandomGamma()(negative_tensor)
39