1
|
|
|
import copy |
2
|
|
|
import torch |
3
|
|
|
import numpy as np |
4
|
|
|
import torchio as tio |
5
|
|
|
import SimpleITK as sitk |
6
|
|
|
from ..utils import TorchioTestCase |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
class TestTransforms(TorchioTestCase):
    """Tests for all transforms."""

    # Transform names skipped by the channel-count and "second t1 channel
    # changed" checks in ``test_transforms_subject_4d``. Presumably these act
    # on labels/metadata rather than t1 intensities — confirm before editing.
    # Kept at class level so the set is not rebuilt on every loop iteration.
    _UNCHECKED_IN_4D = frozenset((
        'RandomLabelsToImage',
        'RemapLabels',
        'RemoveLabels',
        'SequentialLabels',
        'CopyAffine',
    ))

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a ``tio.Compose`` chaining a long list of transforms.

        Args:
            channels: Image names. Each one gets a histogram-standardization
                landmarks entry, and ``channels[0]`` is the image whose
                affine ``tio.CopyAffine`` copies.
            is_3d: If ``False``, use arguments suited to images whose last
                spatial dimension has size 1 (e.g. crop/pad to
                ``(21, 30, 1)``).
            labels: If ``True``, append ``tio.RandomLabelsToImage`` that
                reads the ``'label'`` image.

        Returns:
            The composed transform.
        """
        landmarks_dict = {
            channel: np.linspace(0, 100, 13) for channel in channels
        }
        disp = 1 if is_3d else (1, 1, 0.01)
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resize(resize_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.CopyAffine(channels[0]),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            tio.OneOf({
                tio.RandomAffine(): 3,
                elastic: 1,
            }),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def _get_image_transform(self):
        """Return the composed transform used for single-image inputs.

        The single channel is named ``'default_image_name'`` and no label
        transform is appended.
        """
        return self.get_transform(
            channels=('default_image_name',), labels=False)

    def _assert_only_t1_changed(self, transform):
        """Apply ``transform`` to the sample subject; check t1 only changed.

        Args:
            transform: Transform expected to modify ``t1`` and leave ``t2``
                untouched.
        """
        original_subject = copy.deepcopy(self.sample_subject)
        transformed = transform(self.sample_subject)
        # This assertion fails when t1 is unchanged, hence "No changes".
        self.assertTensorNotEqual(
            original_subject.t1.data,
            transformed.t1.data,
            f'No changes after {transform.name}',
        )
        self.assertTensorEqual(
            original_subject.t2.data,
            transformed.t2.data,
            f'Changes after {transform.name}',
        )

    def test_transforms_dict(self):
        """A plain dict of tensors with explicit ``include`` stays a dict."""
        transform = tio.RandomNoise(include=('t1', 't2'))
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        transformed = transform(input_dict)
        self.assertIsInstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        """A plain dict without ``include`` raises ``RuntimeError``."""
        transform = tio.RandomNoise()
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        with self.assertRaises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        """A ``tio.ScalarImage`` input yields a ``tio.ScalarImage``."""
        transform = self._get_image_transform()
        transformed = transform(self.sample_subject.t1)
        self.assertIsInstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        """A ``torch.Tensor`` input yields a ``torch.Tensor``."""
        tensor = torch.rand(2, 4, 5, 8)
        transform = self._get_image_transform()
        transformed = transform(tensor)
        self.assertIsInstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        """A NumPy array input yields a NumPy array."""
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transform = self._get_image_transform()
        transformed = transform(tensor)
        self.assertIsInstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        """A SimpleITK image input yields a SimpleITK image."""
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self._get_image_transform()
        transformed = transform(image)
        self.assertIsInstance(transformed, sitk.Image)

    def test_transforms_subject_3d(self):
        """The full 3D pipeline returns a ``tio.Subject``."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        """The full 2D pipeline returns a ``tio.Subject``."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        """Each transform, applied one by one, keeps multiple channels."""
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            trsf_channels = len(transformed.t1.data)
            assert trsf_channels > 1, f'Lost channels in {transform.name}'
            if transform.name not in self._UNCHECKED_IN_4D:
                self.assertEqual(
                    subject.shape[0],
                    transformed.shape[0],
                    f'Different number of channels after {transform.name}',
                )
                self.assertTensorNotEqual(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    f'No changes after {transform.name}',
                )
            # Feed each output into the next transform.
            subject = transformed
        self.assertIsInstance(transformed, tio.Subject)

    def test_transform_noop(self):
        """With ``p=0`` the input object itself is returned unchanged."""
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        self.assertIs(transformed, self.sample_subject)
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(tensor)
        self.assertIs(transformed, tensor)

    def test_original_unchanged(self):
        """No transform modifies the t1 data of the subject passed to it."""
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)
            self.assertTensorEqual(
                subject.t1.data,
                original_data,
                f'Changes after {transform.name}',
            )

    def test_transforms_use_include(self):
        """``include=['t1']`` modifies t1 and leaves t2 unchanged."""
        self._assert_only_t1_changed(tio.RandomNoise(include=['t1']))

    def test_transforms_use_exclude(self):
        """``exclude=['t2']`` modifies t1 and leaves t2 unchanged."""
        self._assert_only_t1_changed(tio.RandomNoise(exclude=['t2']))

    def test_transforms_use_include_and_exclude(self):
        """Passing both ``include`` and ``exclude`` raises ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        """The ``keys`` argument emits a ``UserWarning``."""
        with self.assertWarns(UserWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        """``keep={old: new}`` stores the pre-transform image under ``new``."""
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        assert old in transformed
        assert new in transformed
        self.assertTensorEqual(
            transformed[new].data,
            subject[old].data,
        )
        self.assertTensorNotEqual(
            transformed[new].data,
            transformed[old].data,
        )
221
|
|
|
|
222
|
|
|
|
223
|
|
|
class TestTransform(TorchioTestCase):
    """Tests for the ``tio.Transform`` base class and shared helpers."""

    def test_abstract_transform(self):
        """Instantiating ``tio.Transform`` directly raises ``TypeError``."""
        with self.assertRaises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        """Scalar arguments are reported as not being dicts."""
        noise = tio.Noise(0, 1, 0)
        self.assertFalse(noise.arguments_are_dict())

    def test_arguments_are_dict(self):
        """Dict arguments are reported as dicts."""
        noise = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        self.assertTrue(noise.arguments_are_dict())

    def test_arguments_are_and_are_not_dict(self):
        """Mixing scalar and dict arguments raises ``ValueError``."""
        mixed = tio.Noise(0, {'im': 1}, {'im': 0})
        with self.assertRaises(ValueError):
            mixed.arguments_are_dict()

    def test_bad_over_max(self):
        """A scalar above ``max_constraint`` is rejected."""
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range(2, 'name', max_constraint=1)

    def test_bad_over_max_range(self):
        """A range whose upper end exceeds ``max_constraint`` is rejected."""
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range((0, 2), 'name', max_constraint=1)

    def test_bad_type(self):
        """A float is rejected when ``type_constraint=int``."""
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range(2.5, 'name', type_constraint=int)

    def test_no_numbers(self):
        """A non-numeric value is rejected."""
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range('j', 'name')

    def test_apply_transform_missing(self):
        """A subclass without ``apply_transform`` raises ``TypeError``."""
        class Incomplete(tio.Transform):
            pass

        # NOTE(review): if ``apply_transform`` is abstract, the TypeError is
        # raised by ``Incomplete()`` before ``apply_transform`` is reached —
        # confirm this test exercises what its name suggests.
        with self.assertRaises(TypeError):
            Incomplete().apply_transform(0)

    def test_non_invertible(self):
        """Inverting ``RandomBlur`` raises ``RuntimeError``."""
        blur = tio.RandomBlur()
        with self.assertRaises(RuntimeError):
            blur.inverse()

    def test_bad_bounds_mask(self):
        """An unknown ``masking_method`` string raises ``ValueError``."""
        normalization = tio.ZNormalization(masking_method='test')
        with self.assertRaises(ValueError):
            normalization(self.sample_subject)

    def test_bounds_mask(self):
        """Anatomical-label and bounds masks select the expected voxels."""
        normalization = tio.ZNormalization()
        with self.assertRaises(ValueError):
            normalization.get_mask_from_anatomical_label('test', 0)
        volume = torch.rand((1, 2, 2, 2))

        # (label, index prefix selecting the axis, position expected full).
        # The other position along that axis must be empty.
        every = slice(None)
        expectations = (
            ('Left', (every,), 0),
            ('Right', (every,), 1),
            ('Posterior', (every, every), 0),
            ('Anterior', (every, every), 1),
            ('Inferior', (Ellipsis,), 0),
            ('Superior', (Ellipsis,), 1),
        )
        for label, prefix, full in expectations:
            mask = normalization.get_mask_from_anatomical_label(label, volume)
            empty = 1 - full
            selected = mask[prefix + (full,)].sum()
            unselected = mask[prefix + (empty,)].sum()
            assert selected == 4 and unselected == 0, label

        bounds_mask = normalization.get_mask_from_bounds(3 * (0, 1), volume)
        assert bounds_mask[0, 0, 0, 0] == 1
        assert bounds_mask.sum() == 1
304
|
|
|
|