1
|
|
|
import torch |
2
|
|
|
|
3
|
|
|
import torchio as tio |
4
|
|
|
|
5
|
|
|
from ...utils import TorchioTestCase |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
class TestCrop(TorchioTestCase):
    """Check the output shapes produced by ``tio.Crop`` on tensors and subjects.

    Every test uses random 10x10x10 volumes cropped by 1 voxel on each side,
    so one application leaves 8 voxels per spatial dimension and two
    applications leave 6.
    """

    @staticmethod
    def _random_subject():
        """Return a subject with a single-channel 10x10x10 ``ScalarImage`` as ``t1``."""
        volume = torch.rand(1, 10, 10, 10)
        return tio.Subject(t1=tio.ScalarImage(tensor=volume))

    def _assert_tensor_crop(self, num_channels):
        """Crop a random ``(num_channels, 10, 10, 10)`` tensor and check its shape.

        The channel dimension must be left alone; only spatial dims shrink.
        """
        transform = tio.Crop(1)
        volume = torch.rand(num_channels, 10, 10, 10)
        expected = (num_channels, 8, 8, 8)
        assert transform(volume).shape == expected

    def test_tensor_single_channel(self):
        """A one-channel tensor is cropped to 8 voxels per spatial dim."""
        self._assert_tensor_crop(1)

    def test_tensor_multi_channel(self):
        """A three-channel tensor keeps all channels after cropping."""
        self._assert_tensor_crop(3)

    def test_subject_copy(self):
        """With ``copy=True`` the input subject is left unchanged by each call."""
        transform = tio.Crop(1, copy=True)
        original = self._random_subject()

        once = transform(original)
        assert once.t1.shape == (1, 8, 8, 8)
        # The input must still have its original size.
        assert original.t1.shape == (1, 10, 10, 10)

        twice = transform(once)
        assert twice.t1.shape == (1, 6, 6, 6)
        assert once.t1.shape == (1, 8, 8, 8)
        # The second output records both crops; the first output's history
        # must not have been extended by the second call.
        assert len(twice.applied_transforms) == 2
        assert len(once.applied_transforms) == 1

    def test_subject_no_copy(self):
        """With ``copy=False`` the input subject itself ends up cropped."""
        transform = tio.Crop(1, copy=False)
        original = self._random_subject()
        result = transform(original)
        assert result.t1.shape == (1, 8, 8, 8)
        # No copy was made, so the caller's subject reflects the crop too.
        assert original.t1.shape == (1, 8, 8, 8)
36
|
|
|
|