Passed
Push — master ( 79d509...08948f )
by Fernando
01:20
created

TestCropOrPad.test_empty_mask()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import numpy as np
2
from torchio.transforms import CropOrPad
3
from torchio import DATA, AFFINE
4
from ...utils import TorchioTestCase
5
6
7
class TestCropOrPad(TorchioTestCase):
    """Tests for `CropOrPad`."""

    # ------------------------------------------------------------------
    # Private helpers (leading underscore: not collected as tests)
    # ------------------------------------------------------------------

    def _zero_label_mask(self):
        """Zero the subject's 'label' data in place and return it.

        The transforms under test read the mask from ``self.sample_subject``,
        so mutating the tensor in place is intentional: callers may then
        paint new foreground voxels into the returned object.
        """
        mask = self.sample_subject['label'][DATA]
        mask *= 0
        return mask

    def _assert_target_shape(self, transformed, target_shape):
        """Assert every image in ``transformed`` has ``target_shape``."""
        for key in transformed:
            result_shape = transformed[key].spatial_shape
            self.assertEqual(
                target_shape, result_shape,
                f'Wrong shape for image: {key}',
            )

    def _assert_crop_or_pad_shape(self, target_shape):
        """Apply an unmasked ``CropOrPad`` and check the output shapes."""
        transform = CropOrPad(target_shape)
        transformed = transform(self.sample_subject)
        self._assert_target_shape(transformed, target_shape)

    def _assert_masked_crop_or_pad_shape(self, target_shape):
        """Apply a masked ``CropOrPad`` and check all outputs agree in shape.

        A small foreground box is painted into the 'label' image so that the
        mask is non-empty.
        """
        transform = CropOrPad(target_shape, mask_name='label')
        mask = self._zero_label_mask()
        mask[0, 4:6, 5:8, 3:7] = 1
        transformed = transform(self.sample_subject)
        set_shapes = {transformed[key].spatial_shape for key in transformed}
        message = f'Images have different shapes: {set_shapes}'
        # Use a TestCase assertion: a bare ``assert`` is removed under -O.
        self.assertEqual(len(set_shapes), 1, message)
        self._assert_target_shape(transformed, target_shape)

    def _assert_masked_matches_center(self, target_shape):
        """Assert masked and unmasked ``CropOrPad`` produce identical images.

        Callers prepare the 'label' mask beforehand; both transforms are then
        applied to the same subject and compared image by image.
        """
        transform_center = CropOrPad(target_shape)
        transform_mask = CropOrPad(target_shape, mask_name='label')
        transformed_center = transform_center(self.sample_subject)
        transformed_mask = transform_mask(self.sample_subject)
        zipped = zip(transformed_center.values(), transformed_mask.values())
        for image_center, image_mask in zipped:
            self.assertTensorEqual(
                image_center[DATA], image_mask[DATA],
                'Data is different after cropping',
            )
            self.assertTensorEqual(
                image_center[AFFINE], image_mask[AFFINE],
                'Physical position is different after cropping',
            )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_no_changes(self):
        """Using the image's own shape leaves data and affine unchanged."""
        sample_t1 = self.sample_subject['t1']
        shape = sample_t1.spatial_shape
        transform = CropOrPad(shape)
        transformed = transform(self.sample_subject)
        self.assertTensorEqual(sample_t1[DATA], transformed['t1'][DATA])
        self.assertTensorEqual(sample_t1[AFFINE], transformed['t1'][AFFINE])

    def test_no_changes_mask(self):
        """An empty mask with the original shape warns and changes nothing."""
        sample_t1 = self.sample_subject['t1']
        self._zero_label_mask()
        shape = sample_t1.spatial_shape
        transform = CropOrPad(shape, mask_name='label')
        with self.assertWarns(UserWarning):
            transformed = transform(self.sample_subject)
        for key in transformed:
            image_dict = self.sample_subject[key]
            self.assertTensorEqual(image_dict[DATA], transformed[key][DATA])
            self.assertTensorEqual(image_dict[AFFINE], transformed[key][AFFINE])

    def test_different_shape(self):
        """Every output shape differs from the original t1 shape."""
        shape = self.sample_subject['t1'].spatial_shape
        target_shape = 9, 21, 30
        transform = CropOrPad(target_shape)
        transformed = transform(self.sample_subject)
        for key in transformed:
            result_shape = transformed[key].spatial_shape
            self.assertNotEqual(shape, result_shape)

    def test_shape_right(self):
        """Every output image has exactly the requested shape."""
        self._assert_crop_or_pad_shape((9, 21, 30))

    def test_only_pad(self):
        """Target shape that (per the test name) only needs padding."""
        self._assert_crop_or_pad_shape((11, 22, 30))

    def test_only_crop(self):
        """Target shape that (per the test name) only needs cropping."""
        self._assert_crop_or_pad_shape((9, 18, 30))

    def test_shape_negative(self):
        """A negative target shape is rejected."""
        with self.assertRaises(ValueError):
            CropOrPad(-1)

    def test_shape_float(self):
        """A non-integer target shape is rejected."""
        with self.assertRaises(ValueError):
            CropOrPad(2.5)

    def test_shape_string(self):
        """A string target shape is rejected."""
        with self.assertRaises(ValueError):
            CropOrPad('')

    def test_shape_one(self):
        """A scalar target shape is applied to all three spatial axes."""
        transform = CropOrPad(1)
        transformed = transform(self.sample_subject)
        self._assert_target_shape(transformed, (1, 1, 1))

    def test_wrong_mask_name(self):
        """A mask name absent from the subject emits a UserWarning."""
        cop = CropOrPad(1, mask_name='wrong')
        with self.assertWarns(UserWarning):
            cop(self.sample_subject)

    def test_empty_mask(self):
        """An all-zero mask emits a UserWarning."""
        target_shape = 8, 22, 30
        transform = CropOrPad(target_shape, mask_name='label')
        self._zero_label_mask()
        with self.assertWarns(UserWarning):
            transform(self.sample_subject)

    def test_mask_only_pad(self):
        """Masked transform to a target that only needs padding."""
        self._assert_masked_crop_or_pad_shape((11, 22, 30))

    def test_mask_only_crop(self):
        """Masked transform to a target that only needs cropping."""
        self._assert_masked_crop_or_pad_shape((9, 18, 30))

    def test_center_mask(self):
        """The mask bounding box and the input image have the same center"""
        mask = self._zero_label_mask()
        mask[0, 4:6, 9:11, 14:16] = 1
        self._assert_masked_matches_center((8, 22, 30))

    def test_mask_corners(self):
        """A mask touching opposite corners spans the whole image.

        Its bounding box is then the full volume, so masked cropping must
        match plain center cropping.
        """
        mask = self._zero_label_mask()
        mask[0, 0, 0, 0] = 1
        mask[0, -1, -1, -1] = 1
        self._assert_masked_matches_center((8, 22, 30))

    def test_mask_origin(self):
        """A mask containing only voxel (0, 0, 0) is centered in the output."""
        target_shape = 7, 21, 29
        # Integer floor division == floor(shape / 2) for positive sizes.
        center_voxel = np.array(target_shape) // 2
        transform_center = CropOrPad(target_shape)
        transform_mask = CropOrPad(target_shape, mask_name='label')
        mask = self._zero_label_mask()
        mask[0, 0, 0, 0] = 1
        transformed_center = transform_center(self.sample_subject)
        transformed_mask = transform_mask(self.sample_subject)
        zipped = zip(transformed_center.values(), transformed_mask.values())
        for image_center, image_mask in zipped:
            # Arrays are different
            self.assertTensorNotEqual(image_center[DATA], image_mask[DATA])
            # Rotation matrix doesn't change
            center_rotation = image_center[AFFINE][:3, :3]
            mask_rotation = image_mask[AFFINE][:3, :3]
            self.assertTensorEqual(center_rotation, mask_rotation)
            # Origin does change
            center_origin = image_center[AFFINE][:3, 3]
            mask_origin = image_mask[AFFINE][:3, 3]
            self.assertTensorNotEqual(center_origin, mask_origin)
            # Voxel at origin is center of transformed image
            origin_value = image_center[DATA][0, 0, 0, 0]
            i, j, k = center_voxel
            transformed_value = image_mask[DATA][0, i, j, k]
            self.assertEqual(origin_value, transformed_value)