Passed
Push — main ( 220e88...2d661a )
by Fernando
01:34
created

tests.transforms.preprocessing.test_crop   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 36
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 28
dl 0
loc 36
rs 10
c 0
b 0
f 0
wmc 4

4 Methods

Rating   Name   Duplication   Size   Complexity  
A TestCrop.test_tensor_multi_channel() 0 3 1
A TestCrop.test_tensor_single_channel() 0 3 1
A TestCrop.test_subject_no_copy() 0 6 1
A TestCrop.test_subject_copy() 0 12 1
1
import torch
2
3
import torchio as tio
4
5
from ...utils import TorchioTestCase
6
7
8
class TestCrop(TorchioTestCase):
    """Tests for ``tio.Crop`` applied to raw tensors and to subjects."""

    @staticmethod
    def _make_subject():
        """Return a subject holding one random (1, 10, 10, 10) scalar image."""
        tensor = torch.rand(1, 10, 10, 10)
        return tio.Subject(t1=tio.ScalarImage(tensor=tensor))

    def test_tensor_single_channel(self):
        """``Crop(1)`` shrinks every spatial dimension of size 10 to 8."""
        crop = tio.Crop(1)
        assert crop(torch.rand(1, 10, 10, 10)).shape == (1, 8, 8, 8)

    def test_tensor_multi_channel(self):
        """The channel dimension is preserved while spatial dims are cropped."""
        crop = tio.Crop(1)
        assert crop(torch.rand(3, 10, 10, 10)).shape == (3, 8, 8, 8)

    def test_subject_copy(self):
        """With ``copy=True`` the input subject's data and history are untouched."""
        crop = tio.Crop(1, copy=True)
        subject = self._make_subject()
        cropped_subject = crop(subject)
        assert cropped_subject.t1.shape == (1, 8, 8, 8)
        assert subject.t1.shape == (1, 10, 10, 10)
        # The transform history must not leak back into the original subject.
        assert len(subject.applied_transforms) == 0

        # Cropping the copy again must not affect the intermediate subject.
        cropped2_subject = crop(cropped_subject)
        assert cropped2_subject.t1.shape == (1, 6, 6, 6)
        assert cropped_subject.t1.shape == (1, 8, 8, 8)
        assert len(cropped2_subject.applied_transforms) == 2
        assert len(cropped_subject.applied_transforms) == 1

    def test_subject_no_copy(self):
        """With ``copy=False`` the input subject itself is cropped in place."""
        crop = tio.Crop(1, copy=False)
        subject = self._make_subject()
        cropped_subject = crop(subject)
        # In-place mode hands back the very same object it was given.
        assert cropped_subject is subject
        assert cropped_subject.t1.shape == (1, 8, 8, 8)
        assert subject.t1.shape == (1, 8, 8, 8)
36