Passed
Pull Request — master (#418)
by Fernando
01:25
created

TestInvertibility.test_different_interpolation()   A

Complexity

Conditions 1

Size

Total Lines 37
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 31
nop 1
dl 0
loc 37
rs 9.1359
c 0
b 0
f 0
1
import copy
2
import warnings
3
4
import torch
5
import torchio as tio
6
from torchio.transforms.intensity_transform import IntensityTransform
7
from ..utils import TorchioTestCase
8
9
10
class TestInvertibility(TorchioTestCase):
    """Tests for inverting the transforms recorded in a subject's history."""

    def test_all_random_transforms(self):
        """Apply a large random pipeline, invert it and compare geometry.

        Asserts that the ``t1`` shape and the ``label`` affine of the
        transformed subject match those of the inverted subject.
        """
        transform = self.get_large_composed_transform()
        # Remove RandomLabelsToImage, as it would add a new image to the
        # subject. Breaking right after the removal keeps the in-place
        # mutation of the list being iterated safe.
        for t in transform.transforms:
            if t.name == 'RandomLabelsToImage':
                transform.transforms.remove(t)
                break
        # Silence RuntimeWarnings emitted while transforming and inverting
        # (e.g. from elastic deformation, gamma, or transforms that cannot
        # be inverted)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', RuntimeWarning)
            transformed = transform(self.sample_subject)
            inverting_transform = transformed.get_inverse_transform()
            transformed_back = inverting_transform(transformed)
        # NOTE(review): these compare `transformed` (not the original
        # `self.sample_subject`) with `transformed_back`; if the intent is a
        # round-trip check against the input, compare with the sample
        # instead — confirm intent.
        self.assertEqual(
            transformed.t1.shape,
            transformed_back.t1.shape,
        )
        self.assertTensorEqual(
            transformed.label.affine,
            transformed_back.label.affine,
        )

    def test_ignore_intensity(self):
        """The inverse of a composed transform contains no intensity transforms."""
        composed = self.get_large_composed_transform()
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', RuntimeWarning)
            transformed = composed(self.sample_subject)
        # warn=False presumably silences warnings about transforms that are
        # skipped during inversion — TODO confirm against the torchio API
        inverse_transform = transformed.get_inverse_transform(warn=False)
        for transform in inverse_transform:
            assert not isinstance(transform, IntensityTransform)

    def test_different_interpolation(self):
        """The interpolation used when inverting can be overridden.

        A random binary image is transformed with B-spline interpolation,
        which the test expects to overshoot outside [0, 1]. Inverting with
        the default interpolation is expected to overshoot again, whereas
        overriding it with ``'linear'`` keeps probability values within
        [0, 1] and ``'nearest'`` keeps label values in {0, 1}.
        """
        def model_probs(subject):
            # Stand-in for a model predicting per-voxel probabilities in
            # [0, 1); works on a copy so the input subject is not modified
            subject = copy.deepcopy(subject)
            subject.im.set_data(torch.rand_like(subject.im.data))
            return subject

        def model_label(subject):
            # Stand-in for a model predicting a binary label map, obtained
            # by sampling a Bernoulli variable per voxel
            subject = model_probs(subject)
            subject.im.set_data(torch.bernoulli(subject.im.data))
            return subject

        transform = tio.RandomAffine(image_interpolation='bspline')
        tensor = (torch.rand(1, 20, 20, 20) > 0.5).float()  # 0s and 1s
        subject = tio.Subject(im=tio.ScalarImage(tensor=tensor))
        transformed = transform(subject)
        # B-spline interpolation of a binary image overshoots [0, 1]
        assert transformed.im.data.min() < 0
        assert transformed.im.data.max() > 1

        subject_probs = model_probs(transformed)
        # Default inverse: expected to keep the B-spline overshoot
        transformed_back = subject_probs.apply_inverse_transform()
        assert transformed_back.im.data.min() < 0
        assert transformed_back.im.data.max() > 1
        # Linear interpolation cannot leave the range of its inputs
        transformed_back_linear = subject_probs.apply_inverse_transform(
            image_interpolation='linear',
        )
        assert transformed_back_linear.im.data.min() >= 0
        assert transformed_back_linear.im.data.max() <= 1

        subject_label = model_label(transformed)
        transformed_back = subject_label.apply_inverse_transform()
        assert transformed_back.im.data.min() < 0
        assert transformed_back.im.data.max() > 1
        # Nearest-neighbour interpolation only copies existing label values
        transformed_back_nearest = subject_label.apply_inverse_transform(
            image_interpolation='nearest',
        )
        assert transformed_back_nearest.im.data.unique().tolist() == [0, 1]
81