|
1
|
|
|
import pytest |
|
2
|
|
|
import SimpleITK as sitk |
|
3
|
|
|
import torch |
|
4
|
|
|
|
|
5
|
|
|
import torchio as tio |
|
6
|
|
|
from torchio.data.io import sitk_to_nib |
|
7
|
|
|
|
|
8
|
|
|
from ...utils import TorchioTestCase |
|
9
|
|
|
|
|
10
|
|
|
|
|
11
|
|
|
class TestPad(TorchioTestCase):
    """Tests for `Pad`."""

    def test_pad(self):
        """Constant padding must match SimpleITK's ``ConstantPad``.

        Both the voxel data and the resulting affine are compared.
        """
        image = self.sample_subject.t1
        padding = 1, 2, 3, 4, 5, 6
        sitk_image = image.as_sitk()
        # The padding tuple is interleaved as (low, high) pairs per axis;
        # SimpleITK wants the lower and upper bounds as separate sequences.
        low, high = padding[::2], padding[1::2]
        sitk_padded = sitk.ConstantPad(sitk_image, low, high, 0)
        tio_padded = tio.Pad(padding, padding_mode=0)(image)
        sitk_tensor, sitk_affine = sitk_to_nib(sitk_padded)
        tio_tensor, tio_affine = sitk_to_nib(tio_padded.as_sitk())
        self.assert_tensor_equal(sitk_tensor, tio_tensor)
        self.assert_tensor_equal(sitk_affine, tio_affine)

    def test_nans_history(self):
        """Replaying the transform stored in history must not create NaNs."""
        padded = tio.Pad(1, padding_mode=2)(self.sample_subject)
        again = padded.history[0](self.sample_subject)
        assert not torch.isnan(again.t1.data).any()

    def test_padding_modes(self):
        """Constants, named modes and callables are accepted; bad names fail."""

        def padding_func():
            # Only the constructor is exercised below, so the body of this
            # callable is never run; it just needs to be a callable.
            return

        for padding_mode in [0, *tio.Pad.PADDING_MODES, padding_func]:
            tio.Pad(0, padding_mode=padding_mode)

        with self.assertRaises(KeyError):
            tio.Pad(0, padding_mode='abc')

    def test_padding_mean_label_map(self):
        """Mean padding applied to a label map emits a RuntimeWarning."""
        with self.assertWarns(RuntimeWarning):
            tio.Pad(1, padding_mode='mean')(self.sample_subject.label)

    def test_padding_modes_global(self):
        """Statistic-based padding modes use values computed over the image."""
        x = torch.ones(1, 1, 2, 2, dtype=torch.int)
        x[..., 0, 0] = 0
        # The image should look like this:
        # 0 1
        # 1 1

        # Pad one voxel at the upper end of the second spatial axis only.
        add_bottom_row = 0, 0, 0, 1, 0, 0
        with_zeros = tio.Pad(add_bottom_row)(x)
        assert with_zeros[0, 0, 2].tolist() == [0, 0]

        with_minimum = tio.Pad(add_bottom_row, padding_mode='minimum')(x)
        assert with_minimum[0, 0, 2].tolist() == [0, 0]

        with_maximum = tio.Pad(add_bottom_row, padding_mode='maximum')(x)
        assert with_maximum[0, 0, 2].tolist() == [1, 1]

        with_median = tio.Pad(add_bottom_row, padding_mode='median')(x)
        assert with_median[0, 0, 2].tolist() == [1, 1]

        # This is a special case: as we instantiated the tensor with integers,
        # the mean (3/4) will be truncated to 0.
        with_mean = tio.Pad(add_bottom_row, padding_mode='mean')(x)
        assert with_mean[0, 0, 2].tolist() == [0, 0]
        # So let's test with floats too (0.75 is exactly representable).
        x = x.float()
        with_mean = tio.Pad(add_bottom_row, padding_mode='mean')(x)
        assert with_mean[0, 0, 2].tolist() == [0.75, 0.75]

    def test_truncation_warning(self):
        """Mean padding of an integer tensor should warn (possible truncation)."""
        x = torch.ones(1, 1, 2, 2, dtype=torch.int)
        pad = tio.Pad(1, padding_mode='mean')
        with pytest.warns(RuntimeWarning):
            pad(x)