|
1
|
|
|
import copy |
|
2
|
|
|
|
|
3
|
|
|
import numpy as np |
|
4
|
|
|
import pytest |
|
5
|
|
|
import SimpleITK as sitk |
|
6
|
|
|
import torch |
|
7
|
|
|
from nibabel.nifti1 import Nifti1Image |
|
8
|
|
|
|
|
9
|
|
|
import torchio as tio |
|
10
|
|
|
|
|
11
|
|
|
from ..utils import TorchioTestCase |
|
12
|
|
|
|
|
13
|
|
|
|
|
14
|
|
|
class TestTransforms(TorchioTestCase):
    """Tests for all transforms."""

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a composition that exercises most TorchIO transforms.

        Args:
            channels: Names of the images for which histogram-standardization
                landmarks are built; ``channels[0]`` is also the reference
                image passed to ``tio.CopyAffine``.
            is_3d: If ``False``, shapes and parameters are adapted to 2D
                inputs (the last spatial size is set to 1 in the crop/pad,
                resize and swap arguments).
            labels: If ``True``, ``tio.RandomLabelsToImage`` reading the
                ``'label'`` image is appended to the pipeline.

        Returns:
            A ``tio.Compose`` containing all the transforms.
        """
        landmarks_dict = {channel: np.linspace(0, 100, 13) for channel in channels}
        # In 2D, keep the displacement along the singleton axis negligible
        disp = 1 if is_3d else (1, 1, 0.01)
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        affine_elastic = tio.RandomAffineElasticDeformation(
            elastic_kwargs={'max_displacement': disp}
        )
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resize(resize_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.CopyAffine(channels[0]),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            affine_elastic,
            tio.OneOf(
                {
                    tio.RandomAffine(): 3,
                    elastic: 1,
                }
            ),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def test_transforms_dict(self):
        """A dict of tensors with explicit ``include`` keys yields a dict."""
        transform = tio.RandomNoise(include=('t1', 't2'))
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        transformed = transform(input_dict)
        assert isinstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        """A dict input without ``include`` keys raises ``RuntimeError``."""
        transform = tio.RandomNoise()
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        with pytest.raises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        """Transforming a ``ScalarImage`` returns a ``ScalarImage``."""
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(self.sample_subject.t1)
        assert isinstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        """Transforming a 4D torch tensor returns a torch tensor."""
        tensor = torch.rand(2, 4, 5, 8)
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(tensor)
        assert isinstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        """Transforming a 4D NumPy array returns a NumPy array."""
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(tensor)
        assert isinstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        """Transforming a SimpleITK image returns a SimpleITK image."""
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(image)
        assert isinstance(transformed, sitk.Image)

    def test_transforms_subject_3d(self):
        """The full 3D pipeline applied to a subject returns a subject."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        assert isinstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        """The 2D-adapted pipeline applied to a 2D subject returns a subject."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        assert isinstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        """Each transform keeps multichannel data and changes channel 1."""
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            trsf_channels = len(transformed.t1.data)
            assert trsf_channels > 1, f'Lost channels in {transform.name}'
            # These transforms are not expected to modify t1 intensities
            # (or may change the channel count), so skip the strict checks
            exclude = (
                'RandomLabelsToImage',
                'RemapLabels',
                'RemoveLabels',
                'SequentialLabels',
                'CopyAffine',
            )
            if transform.name not in exclude:
                assert subject.shape[0] == transformed.shape[0], (
                    f'Different number of channels after {transform.name}'
                )
                self.assert_tensor_not_equal(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    msg=f'No changes after {transform.name}',
                )
            subject = transformed
        assert isinstance(transformed, tio.Subject)

    def test_transform_noop(self):
        """With ``p=0`` the transform returns the very same input object."""
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        assert transformed is self.sample_subject
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(tensor)
        assert transformed is tensor

    def test_original_unchanged(self):
        """No transform may modify the input subject's t1 data in place."""
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)
            self.assert_tensor_equal(
                subject.t1.data,
                original_data,
                msg=f'Changes after {transform.name}',
            )

    def test_transforms_use_include(self):
        """Only images listed in ``include`` are transformed."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(include=['t1'])
        transformed = transform(self.sample_subject)

        # t1 is included, so it must have been modified
        self.assert_tensor_not_equal(
            original_subject.t1.data,
            transformed.t1.data,
            msg=f'No changes after {transform.name}',
        )

        # t2 is not included, so it must be left untouched
        self.assert_tensor_equal(
            original_subject.t2.data,
            transformed.t2.data,
            msg=f'Changes after {transform.name}',
        )

    def test_transforms_use_exclude(self):
        """Images listed in ``exclude`` are not transformed."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(exclude=['t2'])
        transformed = transform(self.sample_subject)

        # t1 is not excluded, so it must have been modified
        self.assert_tensor_not_equal(
            original_subject.t1.data,
            transformed.t1.data,
            msg=f'No changes after {transform.name}',
        )

        # t2 is excluded, so it must be left untouched
        self.assert_tensor_equal(
            original_subject.t2.data,
            transformed.t2.data,
            msg=f'Changes after {transform.name}',
        )

    def test_transforms_use_include_and_exclude(self):
        """Passing both ``include`` and ``exclude`` raises ``ValueError``."""
        with pytest.raises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        """The ``keys`` argument emits a ``FutureWarning``."""
        with pytest.warns(FutureWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        """``keep={old: new}`` stores the untransformed ``old`` under ``new``."""
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        assert old in transformed
        assert new in transformed
        self.assert_tensor_equal(
            transformed[new].data,
            subject[old].data,
        )
        self.assert_tensor_not_equal(
            transformed[new].data,
            transformed[old].data,
        )
|
236
|
|
|
|
|
237
|
|
|
|
|
238
|
|
|
class TestTransform(TorchioTestCase):
    """Tests for the ``Transform`` base class and its helpers."""

    def test_abstract_transform(self):
        """The abstract base class cannot be instantiated."""
        with pytest.raises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        """Scalar arguments are reported as not being dicts."""
        transform = tio.Noise(0, 1, 0)
        assert not transform.arguments_are_dict()

    def test_arguments_are_dict(self):
        """Per-image dict arguments are reported as dicts."""
        transform = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        assert transform.arguments_are_dict()

    def test_arguments_are_and_are_not_dict(self):
        """Mixing scalar and dict arguments raises ``ValueError``."""
        transform = tio.Noise(0, {'im': 1}, {'im': 0})
        with pytest.raises(ValueError):
            transform.arguments_are_dict()

    def test_bad_over_max(self):
        """A scalar above ``max_constraint`` is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range(2, 'name', max_constraint=1)

    def test_bad_over_max_range(self):
        """A range whose upper bound exceeds ``max_constraint`` is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range((0, 2), 'name', max_constraint=1)

    def test_bad_type(self):
        """A float is rejected when ``type_constraint=int``."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range(2.5, 'name', type_constraint=int)

    def test_no_numbers(self):
        """A non-numeric range value is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range('j', 'name')

    def test_apply_transform_missing(self):
        """A subclass that does not implement the abstract API is unusable."""

        class T(tio.Transform):
            pass

        with pytest.raises(TypeError):
            T().apply_transform(0)

    def test_non_invertible(self):
        """Inverting a non-invertible transform raises ``RuntimeError``."""
        transform = tio.RandomBlur()
        with pytest.raises(RuntimeError):
            transform.inverse()

    def test_batch_history(self):
        """Inverting a subject recovered from a collated batch restores shapes."""
        # https://github.com/TorchIO-project/torchio/discussions/743
        subject = self.sample_subject
        transform = tio.Compose(
            [
                tio.RandomAffine(),
                tio.CropOrPad(5),
                tio.OneHot(),
            ]
        )
        dataset = tio.SubjectsDataset([subject], transform=transform)
        loader = tio.SubjectsLoader(
            dataset,
            collate_fn=tio.utils.history_collate,
        )
        batch = tio.utils.get_first_item(loader)
        transformed: tio.Subject = tio.utils.get_subjects_from_batch(batch)[0]
        inverse = transformed.apply_inverse_transform()
        images1 = subject.get_images(intensity_only=False)
        images2 = inverse.get_images(intensity_only=False)
        for image1, image2 in zip(images1, images2, strict=True):
            assert image1.shape == image2.shape

    def test_bad_bounds_mask(self):
        """An unknown ``masking_method`` string raises ``ValueError``."""
        transform = tio.ZNormalization(masking_method='test')
        with pytest.raises(ValueError):
            transform(self.sample_subject)

    def test_bounds_mask(self):
        """Anatomical labels and explicit bounds select the expected voxels."""
        transform = tio.ZNormalization()
        with pytest.raises(ValueError):
            transform.get_mask_from_anatomical_label('test', 0)
        tensor = torch.rand((1, 2, 2, 2))

        def get_mask(label):
            return transform.get_mask_from_anatomical_label(label, tensor)

        # Each half of the 2x2x2 volume contains 4 voxels
        left = get_mask('Left')
        assert left[:, 0].sum() == 4
        assert left[:, 1].sum() == 0
        right = get_mask('Right')
        assert right[:, 1].sum() == 4
        assert right[:, 0].sum() == 0
        posterior = get_mask('Posterior')
        assert posterior[:, :, 0].sum() == 4
        assert posterior[:, :, 1].sum() == 0
        anterior = get_mask('Anterior')
        assert anterior[:, :, 1].sum() == 4
        assert anterior[:, :, 0].sum() == 0
        inferior = get_mask('Inferior')
        assert inferior[..., 0].sum() == 4
        assert inferior[..., 1].sum() == 0
        superior = get_mask('Superior')
        assert superior[..., 1].sum() == 4
        assert superior[..., 0].sum() == 0

        mask = transform.get_mask_from_bounds(3 * (0, 1), tensor)
        assert mask[0, 0, 0, 0] == 1
        assert mask.sum() == 1

    def test_label_keys(self):
        """Entries named in ``label_keys`` are transformed as label maps."""
        # Adapted from the issue in which the feature was requested:
        # https://github.com/TorchIO-project/torchio/issues/866#issue-1222255576
        size = 1, 10, 10, 10
        image = torch.rand(size)
        num_classes = 2  # excluding background
        label = torch.randint(num_classes + 1, size)

        data_dict = {'image': image, 'label': label}

        transform = tio.RandomAffine(
            include=['image', 'label'],
            label_keys=['label'],
        )
        transformed_label = transform(data_dict)['label']

        # If the image is indeed transformed as a label map, nearest neighbor
        # interpolation is used by default and therefore no intermediate values
        # can exist in the output
        num_unique_values = len(torch.unique(transformed_label))
        assert num_unique_values <= num_classes + 1

    def test_nibabel_input(self):
        """NiBabel images (3D and 5D) are accepted and returned as NiBabel."""
        image = self.sample_subject.t1
        image_nib = Nifti1Image(image.data[0].numpy(), image.affine)
        transformed = tio.RandomAffine()(image_nib)
        transformed.get_fdata()
        _ = transformed.affine

        image = self.subject_4d.t1
        # Reorder (C, W, H, D) into NIfTI's (W, H, D, 1, C) 5D layout
        tensor_5d = image.data[np.newaxis].permute(2, 3, 4, 0, 1)
        image_nib = Nifti1Image(tensor_5d.numpy(), image.affine)
        transformed = tio.RandomAffine()(image_nib)
        transformed.get_fdata()
        _ = transformed.affine

    def test_bad_shape(self):
        """A tensor that is not 4D raises an informative ``ValueError``."""
        tensor = torch.rand(1, 2, 3)
        with pytest.raises(ValueError, match='must be a 4D tensor'):
            tio.RandomAffine()(tensor)

    def test_bad_keys_type(self):
        """A bare string passed as ``include`` raises ``ValueError``."""
        # From https://github.com/TorchIO-project/torchio/issues/923
        with pytest.raises(ValueError):
            tio.RandomAffine(include='t1')

    def test_init_args(self):
        """Base arguments are exposed, except ``parse_input`` for containers."""
        transform = tio.Compose([tio.RandomNoise()])
        base_args = transform.get_base_args()
        assert 'parse_input' not in base_args

        transform = tio.OneOf([tio.RandomNoise()])
        base_args = transform.get_base_args()
        assert 'parse_input' not in base_args

        transform = tio.RandomNoise()
        base_args = transform.get_base_args()
        expected_args = (
            'copy',
            'include',
            'exclude',
            'keep',
            'parse_input',
            'label_keys',
        )
        for arg in expected_args:
            assert arg in base_args, f'Missing base argument: {arg}'
|
411
|
|
|
|