1
|
|
|
import copy |
2
|
|
|
|
3
|
|
|
import numpy as np |
4
|
|
|
import pytest |
5
|
|
|
import SimpleITK as sitk |
6
|
|
|
import torch |
7
|
|
|
from nibabel.nifti1 import Nifti1Image |
8
|
|
|
|
9
|
|
|
import torchio as tio |
10
|
|
|
|
11
|
|
|
from ..utils import TorchioTestCase |
12
|
|
|
|
13
|
|
|
|
14
|
|
|
class TestTransforms(TorchioTestCase):
    """Tests for all transforms."""

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a ``tio.Compose`` containing most TorchIO transforms.

        Args:
            channels: Names of the intensity images. Each one is given
                linearly spaced landmarks for ``HistogramStandardization``,
                and the first one is used as the reference for
                ``CopyAffine``.
            is_3d: If ``False``, alternative spatial arguments intended for
                2D subjects are used (e.g. target shapes ending in 1, and
                only axes 0 and 1 for flipping and anisotropic downsampling).
            labels: If ``True``, ``RandomLabelsToImage`` reading the
                ``'label'`` key is appended at the end.

        Returns:
            A ``tio.Compose`` applying the transforms in the listed order.
        """
        landmarks_dict = {channel: np.linspace(0, 100, 13) for channel in channels}
        disp = 1 if is_3d else (1, 1, 0.01)
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        affine_elastic = tio.RandomAffineElasticDeformation(
            elastic_kwargs={'max_displacement': disp}
        )
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resize(resize_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.CopyAffine(channels[0]),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            affine_elastic,
            tio.OneOf(
                {
                    tio.RandomAffine(): 3,
                    elastic: 1,
                }
            ),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def test_transforms_dict(self):
        """A dict of tensors with explicit ``include`` keys returns a dict."""
        transform = tio.RandomNoise(include=('t1', 't2'))
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        transformed = transform(input_dict)
        assert isinstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        """A dict input without ``include`` keys raises ``RuntimeError``."""
        transform = tio.RandomNoise()
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        with pytest.raises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        """A ``ScalarImage`` input yields a ``ScalarImage`` output."""
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(self.sample_subject.t1)
        assert isinstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        """A 4D ``torch.Tensor`` input yields a ``torch.Tensor`` output."""
        tensor = torch.rand(2, 4, 5, 8)
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(tensor)
        assert isinstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        """A 4D NumPy array input yields a NumPy array output."""
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(tensor)
        assert isinstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        """A SimpleITK image input yields a SimpleITK image output."""
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(image)
        assert isinstance(transformed, sitk.Image)

    def test_transforms_subject_3d(self):
        """The full 3D pipeline applied to a subject returns a ``Subject``."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        assert isinstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        """The 2D pipeline applied to a 2D subject returns a ``Subject``."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        assert isinstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        """Apply each transform in turn to a multichannel subject.

        After every transform, ``t1`` must still have more than one channel.
        For transforms not listed in ``exclude``, the subject's channel count
        must be unchanged and the second channel of ``t1`` must differ from
        its value before the transform.
        """
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        # Transforms skipped for the channel-count and "data changed" checks.
        exclude = (
            'RandomLabelsToImage',
            'RemapLabels',
            'RemoveLabels',
            'SequentialLabels',
            'CopyAffine',
        )
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            trsf_channels = len(transformed.t1.data)
            assert trsf_channels > 1, f'Lost channels in {transform.name}'
            if transform.name not in exclude:
                assert subject.shape[0] == transformed.shape[0], (
                    f'Different number of channels after {transform.name}'
                )
                self.assert_tensor_not_equal(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    msg=f'No changes after {transform.name}',
                )
            # Chain: the next transform receives this transform's output
            subject = transformed
        assert isinstance(transformed, tio.Subject)

    def test_transform_noop(self):
        """With ``p=0`` the very same input object is returned."""
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        assert transformed is self.sample_subject
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(tensor)
        assert transformed is tensor

    def test_original_unchanged(self):
        """No transform modifies the ``t1`` data of its input subject."""
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)
            self.assert_tensor_equal(
                subject.t1.data,
                original_data,
                msg=f'Changes after {transform.name}',
            )

    def test_transforms_use_include(self):
        """Only images listed in ``include`` are transformed."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(include=['t1'])
        transformed = transform(self.sample_subject)

        # 't1' is included, so it must have changed
        self.assert_tensor_not_equal(
            original_subject.t1.data,
            transformed.t1.data,
            msg=f'No changes after {transform.name}',
        )

        # 't2' is not included, so it must be untouched
        self.assert_tensor_equal(
            original_subject.t2.data,
            transformed.t2.data,
            msg=f'Changes after {transform.name}',
        )

    def test_transforms_use_exclude(self):
        """Images listed in ``exclude`` are not transformed."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(exclude=['t2'])
        transformed = transform(self.sample_subject)

        # 't1' is not excluded, so it must have changed
        self.assert_tensor_not_equal(
            original_subject.t1.data,
            transformed.t1.data,
            msg=f'No changes after {transform.name}',
        )

        # 't2' is excluded, so it must be untouched
        self.assert_tensor_equal(
            original_subject.t2.data,
            transformed.t2.data,
            msg=f'Changes after {transform.name}',
        )

    def test_transforms_use_include_and_exclude(self):
        """Passing both ``include`` and ``exclude`` raises ``ValueError``."""
        with pytest.raises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        """The ``keys`` argument emits a ``FutureWarning``."""
        with pytest.warns(FutureWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        """``keep`` stores an untransformed copy under a new key."""
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        assert old in transformed
        assert new in transformed
        self.assert_tensor_equal(
            transformed[new].data,
            subject[old].data,
        )
        self.assert_tensor_not_equal(
            transformed[new].data,
            transformed[old].data,
        )
236
|
|
|
|
237
|
|
|
|
238
|
|
|
class TestTransform(TorchioTestCase):
    """Tests for base ``Transform`` behavior, argument parsing and I/O types."""

    def test_abstract_transform(self):
        """The abstract base class cannot be instantiated."""
        with pytest.raises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        """Scalar arguments are not reported as per-image dicts."""
        transform = tio.Noise(0, 1, 0)
        assert not transform.arguments_are_dict()

    def test_arguments_are_dict(self):
        """Dict arguments are reported as per-image dicts."""
        transform = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        assert transform.arguments_are_dict()

    def test_arguments_are_and_are_not_dict(self):
        """Mixing scalar and dict arguments raises ``ValueError``."""
        transform = tio.Noise(0, {'im': 1}, {'im': 0})
        with pytest.raises(ValueError):
            transform.arguments_are_dict()

    def test_bad_over_max(self):
        """A scalar above ``max_constraint`` is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range(2, 'name', max_constraint=1)

    def test_bad_over_max_range(self):
        """A range whose upper bound exceeds ``max_constraint`` is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range((0, 2), 'name', max_constraint=1)

    def test_bad_type(self):
        """A float is rejected when ``type_constraint=int``."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range(2.5, 'name', type_constraint=int)

    def test_no_numbers(self):
        """A non-numeric value is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range('j', 'name')

    def test_apply_transform_missing(self):
        """A subclass that does not implement ``apply_transform`` fails.

        The ``TypeError`` is expected as soon as ``T()`` is instantiated,
        before ``apply_transform`` is called.
        """

        class T(tio.Transform):
            pass

        with pytest.raises(TypeError):
            T().apply_transform(0)

    def test_non_invertible(self):
        """Inverting a non-invertible transform raises ``RuntimeError``."""
        transform = tio.RandomBlur()
        with pytest.raises(RuntimeError):
            transform.inverse()

    def test_batch_history(self):
        """Inverting a subject recovered from a collated batch restores shapes."""
        # https://github.com/TorchIO-project/torchio/discussions/743
        subject = self.sample_subject
        transform = tio.Compose(
            [
                tio.RandomAffine(),
                tio.CropOrPad(5),
                tio.OneHot(),
            ]
        )
        dataset = tio.SubjectsDataset([subject], transform=transform)
        loader = tio.SubjectsLoader(
            dataset,
            collate_fn=tio.utils.history_collate,
        )
        batch = tio.utils.get_first_item(loader)
        transformed: tio.Subject = tio.utils.get_subjects_from_batch(batch)[0]
        inverse = transformed.apply_inverse_transform()
        images1 = subject.get_images(intensity_only=False)
        images2 = inverse.get_images(intensity_only=False)
        for image1, image2 in zip(images1, images2):
            assert image1.shape == image2.shape

    def test_bad_bounds_mask(self):
        """An unknown ``masking_method`` string raises ``ValueError``."""
        transform = tio.ZNormalization(masking_method='test')
        with pytest.raises(ValueError):
            transform(self.sample_subject)

    def test_bounds_mask(self):
        """Masks built from anatomical labels and from bounds."""
        transform = tio.ZNormalization()
        with pytest.raises(ValueError):
            transform.get_mask_from_anatomical_label('test', 0)
        tensor = torch.rand((1, 2, 2, 2))

        def get_mask(label):
            mask = transform.get_mask_from_anatomical_label(label, tensor)
            return mask

        # Each label selects one half of one spatial axis of the 2x2x2 tensor
        left = get_mask('Left')
        assert left[:, 0].sum() == 4 and left[:, 1].sum() == 0
        right = get_mask('Right')
        assert right[:, 1].sum() == 4 and right[:, 0].sum() == 0
        posterior = get_mask('Posterior')
        assert posterior[:, :, 0].sum() == 4 and posterior[:, :, 1].sum() == 0
        anterior = get_mask('Anterior')
        assert anterior[:, :, 1].sum() == 4 and anterior[:, :, 0].sum() == 0
        inferior = get_mask('Inferior')
        assert inferior[..., 0].sum() == 4 and inferior[..., 1].sum() == 0
        superior = get_mask('Superior')
        assert superior[..., 1].sum() == 4 and superior[..., 0].sum() == 0

        # Bounds (0, 1) on each axis keep only the first voxel
        mask = transform.get_mask_from_bounds(3 * (0, 1), tensor)
        assert mask[0, 0, 0, 0] == 1
        assert mask.sum() == 1

    def test_label_keys(self):
        """Keys in ``label_keys`` are transformed as label maps."""
        # Adapted from the issue in which the feature was requested:
        # https://github.com/TorchIO-project/torchio/issues/866#issue-1222255576
        size = 1, 10, 10, 10
        image = torch.rand(size)
        num_classes = 2  # excluding background
        label = torch.randint(num_classes + 1, size)

        data_dict = {'image': image, 'label': label}

        transform = tio.RandomAffine(
            include=['image', 'label'],
            label_keys=['label'],
        )
        transformed_label = transform(data_dict)['label']

        # If the image is indeed transformed as a label map, nearest neighbor
        # interpolation is used by default and therefore no intermediate values
        # can exist in the output
        num_unique_values = len(torch.unique(transformed_label))
        assert num_unique_values <= num_classes + 1

    def test_nibabel_input(self):
        """NiBabel images (3D and 5D arrays) are accepted and returned."""
        image = self.sample_subject.t1
        image_nib = Nifti1Image(image.data[0].numpy(), image.affine)
        transformed = tio.RandomAffine()(image_nib)
        assert isinstance(transformed, Nifti1Image)
        transformed.get_fdata()
        _ = transformed.affine

        image = self.subject_4d.t1
        # Reorder (C, W, H, D) into the 5D NIfTI layout (W, H, D, 1, C)
        tensor_5d = image.data[np.newaxis].permute(2, 3, 4, 0, 1)
        image_nib = Nifti1Image(tensor_5d.numpy(), image.affine)
        transformed = tio.RandomAffine()(image_nib)
        assert isinstance(transformed, Nifti1Image)
        transformed.get_fdata()
        _ = transformed.affine

    def test_bad_shape(self):
        """A tensor that is not 4D raises ``ValueError``."""
        tensor = torch.rand(1, 2, 3)
        with pytest.raises(ValueError, match='must be a 4D tensor'):
            tio.RandomAffine()(tensor)

    def test_bad_keys_type(self):
        """A bare string passed as ``include`` raises ``ValueError``."""
        # From https://github.com/TorchIO-project/torchio/issues/923
        with pytest.raises(ValueError):
            tio.RandomAffine(include='t1')

    def test_init_args(self):
        """``get_base_args`` contents for composite and simple transforms."""
        for composite in (
            tio.Compose([tio.RandomNoise()]),
            tio.OneOf([tio.RandomNoise()]),
        ):
            base_args = composite.get_base_args()
            assert 'parse_input' not in base_args, type(composite).__name__

        base_args = tio.RandomNoise().get_base_args()
        expected_args = (
            'copy',
            'include',
            'exclude',
            'keep',
            'parse_input',
            'label_keys',
        )
        for arg in expected_args:
            assert arg in base_args, f'Missing base argument "{arg}"'
411
|
|
|
|