Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

TestCropOrPad.test_only_crop()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import warnings
2
import numpy as np
3
from numpy.testing import assert_array_equal
4
from torchio.transforms import CropOrPad, CenterCropOrPad
5
from torchio import DATA, AFFINE
6
from ...utils import TorchioTestCase
7
8
9
class TestCropOrPad(TorchioTestCase):
    """Tests for `CropOrPad`."""

    # ------------------------------------------------------------------
    # Private helpers (names do not start with ``test`` so they are not
    # collected as test cases).
    # ------------------------------------------------------------------

    def _zero_label_mask(self):
        """Zero the fixture's ``label`` data in place and return it.

        The tests write a region into the returned object and then run a
        transform on ``self.sample``, so the zeroing must happen in place
        rather than on a copy.
        """
        mask = self.sample['label'][DATA]
        mask *= 0
        return mask

    def _assert_shapes(self, transformed, expected_shape):
        """Assert every image in ``transformed`` has ``expected_shape``.

        The first dimension of each image's data is skipped when comparing.
        """
        for key in transformed:
            result_shape = transformed[key][DATA].shape[1:]
            self.assertEqual(
                expected_shape,
                result_shape,
                f'Wrong shape for image: {key}',
            )

    def _check_mask_target_shape(self, target_shape):
        """Run a mask-driven `CropOrPad` and check the resulting shapes.

        A small block of ones is written into the zeroed ``label`` mask
        before the transform is applied. All images must end up with one
        common shape, and that shape must be ``target_shape``.
        """
        transform = CropOrPad(target_shape, mask_name='label')
        mask = self._zero_label_mask()
        mask[0, 4:6, 5:8, 3:7] = 1
        transformed = transform(self.sample)
        set_shapes = {transformed[key][DATA].shape[1:] for key in transformed}
        message = f'Images have different shapes: {set_shapes}'
        self.assertEqual(len(set_shapes), 1, message)
        self._assert_shapes(transformed, target_shape)

    def _assert_mask_result_equals_center_result(self, target_shape):
        """Assert mask-driven and plain `CropOrPad` give identical images.

        Must be called after the ``label`` mask has been prepared, because
        both transforms are applied to the current ``self.sample``.
        """
        transform_center = CropOrPad(target_shape)
        transform_mask = CropOrPad(target_shape, mask_name='label')
        transformed_center = transform_center(self.sample)
        transformed_mask = transform_mask(self.sample)
        zipped = zip(transformed_center.values(), transformed_mask.values())
        for image_center, image_mask in zipped:
            assert_array_equal(
                image_center[DATA], image_mask[DATA],
                'Data is different after cropping',
            )
            assert_array_equal(
                image_center[AFFINE], image_mask[AFFINE],
                'Physical position is different after cropping',
            )

    # ------------------------------------------------------------------
    # Tests without a mask
    # ------------------------------------------------------------------

    def test_no_changes(self):
        """A target shape equal to the input shape keeps data and affine."""
        sample_t1 = self.sample['t1']
        shape = sample_t1[DATA].shape[1:]
        transform = CropOrPad(shape)
        transformed = transform(self.sample)
        assert_array_equal(sample_t1[DATA], transformed['t1'][DATA])
        assert_array_equal(sample_t1[AFFINE], transformed['t1'][AFFINE])

    def test_no_changes_mask(self):
        """An all-zero mask warns and, with an unchanged shape, keeps images."""
        sample_t1 = self.sample['t1']
        self._zero_label_mask()
        shape = sample_t1[DATA].shape[1:]
        transform = CropOrPad(shape, mask_name='label')
        with self.assertWarns(UserWarning):
            transformed = transform(self.sample)
        for key in transformed:
            image_dict = self.sample[key]
            assert_array_equal(image_dict[DATA], transformed[key][DATA])
            assert_array_equal(image_dict[AFFINE], transformed[key][AFFINE])

    def test_different_shape(self):
        """The output shape differs from the shape of the ``t1`` input."""
        shape = self.sample['t1'][DATA].shape[1:]
        target_shape = 9, 21, 30
        transform = CropOrPad(target_shape)
        transformed = transform(self.sample)
        for key in transformed:
            result_shape = transformed[key][DATA].shape[1:]
            self.assertNotEqual(shape, result_shape)

    def test_shape_right(self):
        """Every output image has exactly the requested shape."""
        target_shape = 9, 21, 30
        transform = CropOrPad(target_shape)
        self._assert_shapes(transform(self.sample), target_shape)

    def test_only_pad(self):
        """Target shape (11, 22, 30) is reached for every image."""
        target_shape = 11, 22, 30
        transform = CropOrPad(target_shape)
        self._assert_shapes(transform(self.sample), target_shape)

    def test_only_crop(self):
        """Target shape (9, 18, 30) is reached for every image."""
        target_shape = 9, 18, 30
        transform = CropOrPad(target_shape)
        self._assert_shapes(transform(self.sample), target_shape)

    # ------------------------------------------------------------------
    # Argument validation
    # ------------------------------------------------------------------

    def test_shape_negative(self):
        """A negative target shape raises ``ValueError``."""
        with self.assertRaises(ValueError):
            CropOrPad(-1)

    def test_shape_float(self):
        """A non-integer float target shape raises ``ValueError``."""
        with self.assertRaises(ValueError):
            CropOrPad(2.5)

    def test_shape_string(self):
        """A string target shape raises ``ValueError``."""
        with self.assertRaises(ValueError):
            CropOrPad('')

    def test_shape_one(self):
        """A scalar target of 1 produces images of shape (1, 1, 1)."""
        transform = CropOrPad(1)
        self._assert_shapes(transform(self.sample), (1, 1, 1))

    def test_wrong_mask_name(self):
        """Applying the transform with mask name 'wrong' warns the user."""
        cop = CropOrPad(1, mask_name='wrong')
        with self.assertWarns(UserWarning):
            cop(self.sample)

    def test_deprecation(self):
        """Instantiating `CenterCropOrPad` emits a ``DeprecationWarning``."""
        with self.assertWarns(DeprecationWarning):
            CenterCropOrPad(1)

    # ------------------------------------------------------------------
    # Tests driven by the ``label`` mask
    # ------------------------------------------------------------------

    def test_empty_mask(self):
        """An all-zero mask makes the transform emit a ``UserWarning``."""
        target_shape = 8, 22, 30
        transform = CropOrPad(target_shape, mask_name='label')
        self._zero_label_mask()
        with self.assertWarns(UserWarning):
            transform(self.sample)

    def test_mask_only_pad(self):
        """Mask-driven transform reaches target shape (11, 22, 30)."""
        self._check_mask_target_shape((11, 22, 30))

    def test_mask_only_crop(self):
        """Mask-driven transform reaches target shape (9, 18, 30)."""
        self._check_mask_target_shape((9, 18, 30))

    def test_center_mask(self):
        """The mask bounding box and the input image have the same center"""
        target_shape = 8, 22, 30
        mask = self._zero_label_mask()
        mask[0, 4:6, 9:11, 14:16] = 1
        self._assert_mask_result_equals_center_result(target_shape)

    def test_mask_corners(self):
        """A mask with only two opposite corner voxels set acts like no mask.

        The first and last voxels are set, and the result must match the
        transform applied without a mask.
        """
        target_shape = 8, 22, 30
        mask = self._zero_label_mask()
        mask[0, 0, 0, 0] = 1
        mask[0, -1, -1, -1] = 1
        self._assert_mask_result_equals_center_result(target_shape)

    def test_mask_origin(self):
        """A mask with only the first voxel set centers the output on it."""
        target_shape = 7, 21, 29
        center_voxel = np.floor(np.array(target_shape) / 2).astype(int)
        transform_center = CropOrPad(target_shape)
        transform_mask = CropOrPad(target_shape, mask_name='label')
        mask = self._zero_label_mask()
        mask[0, 0, 0, 0] = 1
        transformed_center = transform_center(self.sample)
        transformed_mask = transform_mask(self.sample)
        zipped = zip(transformed_center.values(), transformed_mask.values())
        for image_center, image_mask in zipped:
            # Arrays are different
            self.assertFalse(
                np.array_equal(image_center[DATA], image_mask[DATA]),
            )
            # Rotation matrix doesn't change
            center_rotation = image_center[AFFINE][:3, :3]
            mask_rotation = image_mask[AFFINE][:3, :3]
            assert_array_equal(center_rotation, mask_rotation)
            # Origin does change
            center_origin = image_center[AFFINE][:3, 3]
            mask_origin = image_mask[AFFINE][:3, 3]
            self.assertFalse(np.array_equal(center_origin, mask_origin))
            # Voxel at origin is center of transformed image
            origin_value = image_center[DATA][0, 0, 0, 0]
            i, j, k = center_voxel
            transformed_value = image_mask[DATA][0, i, j, k]
            self.assertEqual(origin_value, transformed_value)
213