Passed
Push — main ( 6ea3bf...edcd90 )
by Fernando
01:38
created

TestPad.test_padding_modes()   A

Complexity

Conditions 3

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import pytest
2
import SimpleITK as sitk
3
import torch
4
5
import torchio as tio
6
from torchio.data.io import sitk_to_nib
7
8
from ...utils import TorchioTestCase
9
10
11
class TestPad(TorchioTestCase):
    """Tests for `Pad`."""

    def test_pad(self):
        """Constant zero padding must match SimpleITK's ``ConstantPad``.

        ``padding`` is given as (low, high) pairs per axis: the even-indexed
        entries are the lower bounds and the odd-indexed entries the upper
        bounds passed to SimpleITK. Both the voxel data and the affine of
        the two results are compared.
        """
        image = self.sample_subject.t1
        padding = 1, 2, 3, 4, 5, 6
        sitk_image = image.as_sitk()
        low, high = padding[::2], padding[1::2]
        sitk_padded = sitk.ConstantPad(sitk_image, low, high, 0)
        tio_padded = tio.Pad(padding, padding_mode=0)(image)
        # Convert both results through the same helper so that only the
        # padding itself can make them differ.
        sitk_tensor, sitk_affine = sitk_to_nib(sitk_padded)
        tio_tensor, tio_affine = sitk_to_nib(tio_padded.as_sitk())
        self.assert_tensor_equal(sitk_tensor, tio_tensor)
        self.assert_tensor_equal(sitk_affine, tio_affine)

    def test_nans_history(self):
        """Replaying a ``Pad`` taken from the history must not yield NaNs."""
        padded = tio.Pad(1, padding_mode=2)(self.sample_subject)
        # history[0] is presumably the Pad applied just above (assumes the
        # fixture subject starts with an empty history -- TODO confirm).
        # Re-apply it to the original subject.
        again = padded.history[0](self.sample_subject)
        assert not torch.isnan(again.t1.data).any()

    def test_padding_modes(self):
        """``Pad`` accepts every supported kind of ``padding_mode``.

        Valid modes are a numeric constant, any name listed in
        ``Pad.PADDING_MODES``, and a callable. An unknown mode name must
        raise ``KeyError``.
        """
        def padding_func():
            return

        # Instantiation alone must not raise for any valid mode.
        for padding_mode in [0, *tio.Pad.PADDING_MODES, padding_func]:
            tio.Pad(0, padding_mode=padding_mode)

        with self.assertRaises(KeyError):
            tio.Pad(0, padding_mode='abc')

    def test_padding_mean_label_map(self):
        """Mean padding a label map must emit a ``RuntimeWarning``."""
        with self.assertWarns(RuntimeWarning):
            tio.Pad(1, padding_mode='mean')(self.sample_subject.label)

    def test_padding_modes_global(self):
        """Statistic-based modes fill the padding with a whole-image value.

        The padded row is expected to hold the minimum, maximum, median or
        mean of the entire 2x2 input, not just of the neighboring row.
        """
        x = torch.ones(1, 1, 2, 2, dtype=torch.int)
        x[..., 0, 0] = 0
        # The image should look like this:
        # 0 1
        # 1 1

        # Pad by one voxel only at the high end of the second spatial axis,
        # which appends a new row at index 2.
        add_bottom_row = 0, 0, 0, 1, 0, 0
        with_zeros = tio.Pad(add_bottom_row)(x)
        assert with_zeros[0, 0, 2].tolist() == [0, 0]

        with_minimum = tio.Pad(add_bottom_row, padding_mode='minimum')(x)
        assert with_minimum[0, 0, 2].tolist() == [0, 0]

        with_maximum = tio.Pad(add_bottom_row, padding_mode='maximum')(x)
        assert with_maximum[0, 0, 2].tolist() == [1, 1]

        with_median = tio.Pad(add_bottom_row, padding_mode='median')(x)
        assert with_median[0, 0, 2].tolist() == [1, 1]

        # This is a special case: as we instantiated the tensor with integers,
        # the mean (3/4) will be truncated to 0.
        with_mean = tio.Pad(add_bottom_row, padding_mode='mean')(x)
        assert with_mean[0, 0, 2].tolist() == [0, 0]
        # So let's test with floats too
        x = x.float()
        with_mean = tio.Pad(add_bottom_row, padding_mode='mean')(x)
        assert with_mean[0, 0, 2].tolist() == [0.75, 0.75]

    def test_truncation_warning(self):
        """Mean padding an integer tensor must emit a ``RuntimeWarning``.

        NOTE(review): the mean of this all-ones tensor is exactly 1, so no
        value is actually lost. The warning is presumably triggered by the
        integer dtype alone -- confirm against ``Pad``'s implementation.
        """
        x = torch.ones(1, 1, 2, 2, dtype=torch.int)
        pad = tio.Pad(1, padding_mode='mean')
        with pytest.warns(RuntimeWarning):
            pad(x)
79