1
|
|
|
import pytest |
2
|
|
|
import SimpleITK as sitk |
3
|
|
|
import torch |
4
|
|
|
|
5
|
|
|
import torchio as tio |
6
|
|
|
from torchio.data.io import sitk_to_nib |
7
|
|
|
|
8
|
|
|
from ...utils import TorchioTestCase |
9
|
|
|
|
10
|
|
|
|
11
|
|
|
class TestPad(TorchioTestCase):
    """Tests for `Pad`."""

    def test_pad(self):
        """Constant padding must match SimpleITK's ``ConstantPad``.

        Both the voxel data and the affine of the padded image are compared
        against the SimpleITK reference.
        """
        image = self.sample_subject.t1
        padding = 1, 2, 3, 4, 5, 6
        sitk_image = image.as_sitk()
        # ``padding`` is interleaved as (low, high) pairs per spatial axis,
        # while SimpleITK expects separate lower-bound and upper-bound tuples.
        low, high = padding[::2], padding[1::2]
        sitk_padded = sitk.ConstantPad(sitk_image, low, high, 0)
        tio_padded = tio.Pad(padding, padding_mode=0)(image)
        sitk_tensor, sitk_affine = sitk_to_nib(sitk_padded)
        tio_tensor, tio_affine = sitk_to_nib(tio_padded.as_sitk())
        self.assert_tensor_equal(sitk_tensor, tio_tensor)
        self.assert_tensor_equal(sitk_affine, tio_affine)

    def test_nans_history(self):
        """Replaying the transform stored in the history must not yield NaNs."""
        padded = tio.Pad(1, padding_mode=2)(self.sample_subject)
        again = padded.history[0](self.sample_subject)
        assert not torch.isnan(again.t1.data).any()

    def test_padding_modes(self):
        """All supported padding modes are accepted; unknown names raise.

        A number, every name in ``Pad.PADDING_MODES`` and a callable must all
        be valid; an unrecognized mode name raises ``KeyError``.
        """
        def padding_func():
            return

        for padding_mode in [0, *tio.Pad.PADDING_MODES, padding_func]:
            tio.Pad(0, padding_mode=padding_mode)

        with self.assertRaises(KeyError):
            tio.Pad(0, padding_mode='abc')

    def test_padding_mean_label_map(self):
        """Mean padding of a label map emits a ``RuntimeWarning``."""
        with self.assertWarns(RuntimeWarning):
            tio.Pad(1, padding_mode='mean')(self.sample_subject.label)

    def test_padding_modes_global(self):
        """Statistic-based modes fill padding with a whole-image statistic.

        The padded row is adjacent to a row of ones, so the expected values
        for 'minimum' (0) and 'mean' (0.75) can only come from statistics
        computed over the entire image, not from the neighbouring voxels.
        """
        x = torch.ones(1, 1, 2, 2, dtype=torch.int)
        x[..., 0, 0] = 0
        # The image should look like this:
        # 0 1
        # 1 1

        add_bottom_row = 0, 0, 0, 1, 0, 0
        with_zeros = tio.Pad(add_bottom_row)(x)
        assert with_zeros[0, 0, 2].tolist() == [0, 0]

        with_minimum = tio.Pad(add_bottom_row, padding_mode='minimum')(x)
        assert with_minimum[0, 0, 2].tolist() == [0, 0]

        with_maximum = tio.Pad(add_bottom_row, padding_mode='maximum')(x)
        assert with_maximum[0, 0, 2].tolist() == [1, 1]

        with_median = tio.Pad(add_bottom_row, padding_mode='median')(x)
        assert with_median[0, 0, 2].tolist() == [1, 1]

        # This is a special case: as we instantiated the tensor with integers,
        # the mean (3/4) will be truncated to 0.
        with_mean = tio.Pad(add_bottom_row, padding_mode='mean')(x)
        assert with_mean[0, 0, 2].tolist() == [0, 0]
        # So let's test with floats too
        x = x.float()
        with_mean = tio.Pad(add_bottom_row, padding_mode='mean')(x)
        assert with_mean[0, 0, 2].tolist() == [0.75, 0.75]

    def test_truncation_warning(self):
        """Mean padding of an integer tensor emits a ``RuntimeWarning``.

        The mean of an integer image cannot be stored exactly in the integer
        tensor (see ``test_padding_modes_global`` in this class), so the user
        is warned.
        """
        x = torch.ones(1, 1, 2, 2, dtype=torch.int)
        pad = tio.Pad(1, padding_mode='mean')
        with pytest.warns(RuntimeWarning):
            pad(x)
79
|
|
|
|