|
1
|
|
|
import torch |
|
2
|
|
|
from torchio import RandomFlip |
|
3
|
|
|
from ...utils import TorchioTestCase |
|
4
|
|
|
|
|
5
|
|
|
|
|
6
|
|
|
class TestRandomFlip(TorchioTestCase):
    """Unit tests covering the `RandomFlip` transform.

    Flips are requested with ``flip_probability=1`` so their outcome is
    deterministic. The invalid-argument cases only check that constructing
    the transform raises ``ValueError``.
    """

    def test_2d(self):
        """Flipping spatial axes 1 and 2 reverses the last two data dims."""
        subject_2d = self.make_2d(self.sample)
        flip = RandomFlip(axes=(1, 2), flip_probability=1)
        flipped = flip(subject_2d)
        # The input is read after the call on purpose: if the transform
        # mutated ``subject_2d`` in place, this expectation would no longer
        # line up with the flipped output.
        expected = subject_2d.t1.data.numpy()[..., ::-1, ::-1]
        actual = flipped.t1.data.numpy()
        self.assertTensorEqual(expected, actual)

    def test_out_of_range_axis(self):
        """A single integer axis outside the accepted range is rejected."""
        self.assertRaises(ValueError, RandomFlip, axes=3)

    def test_out_of_range_axis_in_tuple(self):
        """One invalid entry (-1) inside an axes tuple is enough to fail."""
        self.assertRaises(ValueError, RandomFlip, axes=(0, -1, 2))

    def test_wrong_axes_type(self):
        """``axes=None`` is not an accepted value."""
        self.assertRaises(ValueError, RandomFlip, axes=None)

    def test_wrong_flip_probability_type(self):
        """A string ``flip_probability`` is rejected."""
        self.assertRaises(ValueError, RandomFlip, flip_probability='wrong')

    def test_anatomical_axis(self):
        """An anatomical label may be given instead of an axis index."""
        flip_i = RandomFlip(axes=['i'], flip_probability=1)
        volume = torch.rand(1, 2, 3, 4)
        flipped = flip_i(volume)
        # The test expects label 'i' to act on the last dimension of a
        # plain 4D tensor input.
        reversed_last = volume.numpy()[..., ::-1]
        self.assertTensorEqual(reversed_last, flipped.numpy())
|
41
|
|
|
|