Passed
Push — master ( 0aadcd...59a0a3 )
by Fernando
01:13
created

TestRandomFlip.test_anatomical_axis()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 1     |
| eloc   | 7     |
| nop    | 1     |
| dl     | 0     |
| loc    | 7     |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import torch
2
from torchio import RandomFlip
3
from ...utils import TorchioTestCase
4
5
6
class TestRandomFlip(TorchioTestCase):
    """Unit tests for the `RandomFlip` transform.

    Covers flipping of 2D data, rejection of invalid ``axes`` and
    ``flip_probability`` arguments, and flipping along an axis selected
    by an anatomical label instead of an integer index.
    """

    def _assert_rejected(self, **flip_kwargs):
        """Check that building `RandomFlip` with ``flip_kwargs`` raises ValueError."""
        with self.assertRaises(ValueError):
            RandomFlip(**flip_kwargs)

    def test_2d(self):
        """With probability 1, flipping axes 1 and 2 reverses both of them."""
        subject_2d = self.make_2d(self.sample)
        flip = RandomFlip(flip_probability=1, axes=(1, 2))
        flipped = flip(subject_2d)
        # The reference array is read only after the transform has run,
        # keeping the same evaluation order as the assertion arguments.
        source_array = subject_2d.t1.data.numpy()
        self.assertTensorEqual(
            source_array[..., ::-1, ::-1],
            flipped.t1.data.numpy(),
        )

    def test_out_of_range_axis(self):
        """An integer axis outside the accepted range is rejected."""
        self._assert_rejected(axes=3)

    def test_out_of_range_axis_in_tuple(self):
        """A tuple of axes containing -1 is rejected."""
        self._assert_rejected(axes=(0, -1, 2))

    def test_wrong_axes_type(self):
        """Passing ``axes=None`` is rejected."""
        self._assert_rejected(axes=None)

    def test_wrong_flip_probability_type(self):
        """A string ``flip_probability`` is rejected."""
        self._assert_rejected(flip_probability='wrong')

    def test_anatomical_axis(self):
        """The label 'i' selects an axis; for a plain tensor the last dim is reversed."""
        # Transform is built before the random tensor is drawn, as originally.
        flip = RandomFlip(flip_probability=1, axes=['i'])
        data = torch.rand(1, 2, 3, 4)
        flipped = flip(data)
        expected = data.numpy()[..., ::-1]
        self.assertTensorEqual(expected, flipped.numpy())