Passed
Push — master ( db45b7...6deb01 )
by Fernando
01:30
created

TestCropOrPad.test_no_target()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import numpy as np
2
import torchio as tio
3
from ...utils import TorchioTestCase
4
5
6
class TestCropOrPad(TorchioTestCase):
    """Tests for `CropOrPad`.

    Tests run the transform on ``self.sample_subject`` (provided by
    ``TorchioTestCase``) and check the spatial shape, voxel data and/or
    affine of every image in the result.
    """

    def _zero_label(self):
        """Zero the sample subject's label map in place and return its data.

        The tests rely on ``.data`` being the subject's live tensor, so the
        in-place ``*=`` is what makes the transform see an all-zero mask.
        """
        mask = self.sample_subject['label'].data
        mask *= 0
        return mask

    def _assert_all_shapes(self, transformed, target_shape):
        """Assert that every image in ``transformed`` has ``target_shape``."""
        for key in transformed:
            result_shape = transformed[key].spatial_shape
            self.assertEqual(
                target_shape, result_shape,
                f'Wrong shape for image: {key}',
            )

    def _assert_same_output(self, transformed_a, transformed_b):
        """Assert two transformed subjects have equal data and affines.

        Images are paired by the iteration order of the two subjects.
        """
        zipped = zip(transformed_a.values(), transformed_b.values())
        for image_a, image_b in zipped:
            self.assertTensorEqual(
                image_a.data, image_b.data,
                'Data is different after cropping',
            )
            self.assertTensorEqual(
                image_a.affine, image_b.affine,
                'Physical position is different after cropping',
            )

    def test_no_changes(self):
        """A target equal to the input shape leaves data and affine intact."""
        sample_t1 = self.sample_subject['t1']
        transform = tio.CropOrPad(sample_t1.spatial_shape)
        transformed = transform(self.sample_subject)
        self.assertTensorEqual(sample_t1.data, transformed['t1'].data)
        self.assertTensorEqual(sample_t1.affine, transformed['t1'].affine)

    def test_no_changes_mask(self):
        """An all-zero mask warns; with target == input shape nothing changes."""
        shape = self.sample_subject['t1'].spatial_shape
        self._zero_label()
        transform = tio.CropOrPad(shape, mask_name='label')
        with self.assertWarns(RuntimeWarning):
            transformed = transform(self.sample_subject)
        for key in transformed:
            image = self.sample_subject[key]
            self.assertTensorEqual(image.data, transformed[key].data)
            self.assertTensorEqual(image.affine, transformed[key].affine)

    def test_different_shape(self):
        """A target different from the input shape changes every image."""
        shape = self.sample_subject['t1'].spatial_shape
        target_shape = 9, 21, 30
        transformed = tio.CropOrPad(target_shape)(self.sample_subject)
        for key in transformed:
            result_shape = transformed[key].spatial_shape
            self.assertNotEqual(shape, result_shape)

    def test_shape_right(self):
        """Every output image has exactly the requested shape."""
        target_shape = 9, 21, 30
        transformed = tio.CropOrPad(target_shape)(self.sample_subject)
        self._assert_all_shapes(transformed, target_shape)

    def test_only_pad(self):
        """Shape is correct for a target intended to need only padding."""
        target_shape = 11, 22, 30
        transformed = tio.CropOrPad(target_shape)(self.sample_subject)
        self._assert_all_shapes(transformed, target_shape)

    def test_only_crop(self):
        """Shape is correct for a target intended to need only cropping."""
        target_shape = 9, 18, 30
        transformed = tio.CropOrPad(target_shape)(self.sample_subject)
        self._assert_all_shapes(transformed, target_shape)

    def test_shape_negative(self):
        """A negative target shape is rejected."""
        with self.assertRaises(ValueError):
            tio.CropOrPad(-1)

    def test_shape_float(self):
        """A non-integer target shape is rejected."""
        with self.assertRaises(ValueError):
            tio.CropOrPad(2.5)

    def test_shape_string(self):
        """A string target shape is rejected."""
        with self.assertRaises(ValueError):
            tio.CropOrPad('')

    def test_shape_one(self):
        """A scalar target is applied to all three spatial dimensions."""
        transformed = tio.CropOrPad(1)(self.sample_subject)
        self._assert_all_shapes(transformed, (1, 1, 1))

    def test_wrong_mask_name(self):
        """A mask name missing from the subject triggers a RuntimeWarning."""
        cop = tio.CropOrPad(1, mask_name='wrong')
        with self.assertWarns(RuntimeWarning):
            cop(self.sample_subject)

    def test_empty_mask(self):
        """An all-zero mask triggers a RuntimeWarning."""
        target_shape = 8, 22, 30
        transform = tio.CropOrPad(target_shape, mask_name='label')
        self._zero_label()
        with self.assertWarns(RuntimeWarning):
            transform(self.sample_subject)

    def mask_only(self, target_shape):
        """Crop/pad around a small mask block and check every output shape.

        All images must end up with one common spatial shape equal to
        ``target_shape``.
        """
        transform = tio.CropOrPad(target_shape, mask_name='label')
        mask = self._zero_label()
        mask[0, 4:6, 5:8, 3:7] = 1
        transformed = transform(self.sample_subject)
        set_shapes = {transformed[key].spatial_shape for key in transformed}
        message = f'Images have different shapes: {set_shapes}'
        self.assertEqual(len(set_shapes), 1, message)
        self._assert_all_shapes(transformed, target_shape)

    def test_mask_only_pad(self):
        self.mask_only((11, 22, 30))

    def test_mask_only_crop(self):
        self.mask_only((9, 18, 30))

    def test_center_mask(self):
        """The mask bounding box and the input image have the same center.

        Cropping around the mask must therefore give the same data and
        affine as cropping around the image center.
        """
        target_shape = 8, 22, 30
        transform_center = tio.CropOrPad(target_shape)
        transform_mask = tio.CropOrPad(target_shape, mask_name='label')
        mask = self._zero_label()
        mask[0, 4:6, 9:11, 14:16] = 1
        self._assert_same_output(
            transform_center(self.sample_subject),
            transform_mask(self.sample_subject),
        )

    def test_mask_corners(self):
        """A mask marking two opposite corners spans the whole image.

        Its bounding box is then the full image, so cropping around the mask
        must give the same data and affine as cropping around the center.
        """
        target_shape = 8, 22, 30
        transform_center = tio.CropOrPad(target_shape)
        transform_mask = tio.CropOrPad(target_shape, mask_name='label')
        mask = self._zero_label()
        mask[0, 0, 0, 0] = 1
        mask[0, -1, -1, -1] = 1
        self._assert_same_output(
            transform_center(self.sample_subject),
            transform_mask(self.sample_subject),
        )

    def test_2d(self):
        """A single-slice (2D) subject can be cropped around a mask."""
        # https://github.com/fepegar/torchio/issues/434
        image = np.random.rand(1, 16, 16, 1)
        mask = np.zeros_like(image, dtype=bool)
        mask[0, 7, 0] = True
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=image),
            mask=tio.LabelMap(tensor=mask),
        )
        transform = tio.CropOrPad((12, 12, 1), mask_name='mask')
        transformed = transform(subject)
        self.assertEqual(transformed.shape, (1, 12, 12, 1))

    def test_no_target_no_mask(self):
        """Omitting both the target shape and the mask is rejected."""
        with self.assertRaises(ValueError):
            tio.CropOrPad()

    def test_labels_but_no_mask(self):
        """Passing labels without a mask name is rejected."""
        with self.assertRaises(ValueError):
            tio.CropOrPad(target_shape=(3, 4, 5), labels=[2, 3])

    def test_no_target(self):
        """With only a mask name, the output shape presumably comes from the mask.

        All images must share one spatial shape, no larger than the input
        along any axis.
        """
        original_shape = self.sample_subject['t1'].spatial_shape
        crop_with_mask = tio.CropOrPad(mask_name='label')
        transformed = crop_with_mask(self.sample_subject)
        set_shapes = {transformed[key].spatial_shape for key in transformed}
        message = f'Images have different shapes: {set_shapes}'
        self.assertEqual(len(set_shapes), 1, message)
        (result_shape,) = set_shapes
        for result_size, original_size in zip(result_shape, original_shape):
            self.assertLessEqual(result_size, original_size)
188