1
|
|
|
import copy |
2
|
|
|
import torch |
3
|
|
|
import numpy as np |
4
|
|
|
import nibabel as nib |
5
|
|
|
import torchio as tio |
6
|
|
|
import SimpleITK as sitk |
7
|
|
|
from ..utils import TorchioTestCase |
8
|
|
|
|
9
|
|
|
|
10
|
|
|
class TestTransforms(TorchioTestCase):
    """Tests for all transforms."""

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a ``tio.Compose`` that chains (almost) every transform.

        Args:
            channels: Image names for which histogram-standardization
                landmarks are generated.
            is_3d: If ``False``, several transform arguments are swapped for
                variants meant for 2D inputs (e.g. the ``CropOrPad`` target
                has size 1 along the last spatial axis).
            labels: If ``True``, a ``RandomLabelsToImage`` reading the
                ``'label'`` image is appended at the end of the pipeline.

        Returns:
            A ``tio.Compose`` instance wrapping the list of transforms.
        """
        landmarks_dict = {
            channel: np.linspace(0, 100, 13) for channel in channels
        }
        disp = 1 if is_3d else (1, 1, 0.01)
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            tio.OneOf({
                tio.RandomAffine(): 3,
                elastic: 1,
            }),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def test_transforms_dict(self):
        # A plain dict of tensors is accepted when `include` names the keys.
        transform = tio.RandomNoise(include=('t1', 't2'))
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        transformed = transform(input_dict)
        self.assertIsInstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        # Without `include`, applying the transform to a dict must fail.
        transform = tio.RandomNoise()
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        with self.assertRaises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(self.sample_subject.t1)
        self.assertIsInstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        tensor = torch.rand(2, 4, 5, 8)
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(tensor)
        self.assertIsInstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(tensor)
        self.assertIsInstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(image)
        self.assertIsInstance(transformed, sitk.Image)

    def test_transforms_nib(self):
        data = torch.rand(1, 4, 5, 8).numpy()
        affine = np.diag((1, -2, 3, 1))
        image = nib.Nifti1Image(data, affine)
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(image)
        self.assertIsInstance(transformed, nib.Nifti1Image)

    def test_transforms_subject_3d(self):
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        """Apply each transform in turn to a multichannel subject.

        Checks after every step that channels are preserved and (for
        non-label transforms) that the second channel of ``t1`` changed.
        """
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        # Transforms whose names suggest they act on label maps; ``t1`` is
        # presumably not expected to change (or keep its channel count
        # checked) after them. Loop-invariant, so built once.
        exclude = (
            'RandomLabelsToImage',
            'RemapLabels',
            'RemoveLabels',
            'SequentialLabels',
        )
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            trsf_channels = len(transformed.t1.data)
            assert trsf_channels > 1, f'Lost channels in {transform.name}'
            if transform.name not in exclude:
                self.assertEqual(
                    subject.shape[0],
                    transformed.shape[0],
                    f'Different number of channels after {transform.name}'
                )
                self.assertTensorNotEqual(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    f'No changes after {transform.name}'
                )
            subject = transformed
        self.assertIsInstance(transformed, tio.Subject)

    def test_transform_noop(self):
        # With p=0 the transform must return its input object unchanged.
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        self.assertIs(transformed, self.sample_subject)
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(tensor)
        self.assertIs(transformed, tensor)

    def test_original_unchanged(self):
        # Applying a transform must not modify the input subject in place.
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)
            self.assertTensorEqual(
                subject.t1.data,
                original_data,
                f'Changes after {transform.name}'
            )

    def test_transforms_use_include(self):
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(include=['t1'])
        transformed = transform(self.sample_subject)

        # t1 is included, so it must have been modified...
        self.assertTensorNotEqual(
            original_subject.t1.data,
            transformed.t1.data,
            f'No changes after {transform.name}'
        )

        # ...while t2 is not included and must be left untouched.
        self.assertTensorEqual(
            original_subject.t2.data,
            transformed.t2.data,
            f'Changes after {transform.name}'
        )

    def test_transforms_use_exclude(self):
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(exclude=['t2'])
        transformed = transform(self.sample_subject)

        # t1 is not excluded, so it must have been modified...
        self.assertTensorNotEqual(
            original_subject.t1.data,
            transformed.t1.data,
            f'No changes after {transform.name}'
        )

        # ...while t2 is excluded and must be left untouched.
        self.assertTensorEqual(
            original_subject.t2.data,
            transformed.t2.data,
            f'Changes after {transform.name}'
        )

    def test_transforms_use_include_and_exclude(self):
        # `include` and `exclude` are mutually exclusive.
        with self.assertRaises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        with self.assertWarns(UserWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        # `keep={old: new}` stores an untransformed copy of `old` as `new`.
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        assert old in transformed
        assert new in transformed
        self.assertTensorEqual(
            transformed[new].data,
            subject[old].data,
        )
        self.assertTensorNotEqual(
            transformed[new].data,
            transformed[old].data,
        )
227
|
|
|
|
228
|
|
|
|
229
|
|
|
class TestTransform(TorchioTestCase):
    """Tests for the base ``tio.Transform`` API and its helper methods."""

    def test_abstract_transform(self):
        # Instantiating the base class directly raises TypeError.
        with self.assertRaises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        noise = tio.Noise(0, 1, 0)
        self.assertFalse(noise.arguments_are_dict())

    def test_arguments_are_dict(self):
        noise = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        self.assertTrue(noise.arguments_are_dict())

    def test_arguments_are_and_are_not_dict(self):
        # Mixing a scalar argument with per-image dict arguments is rejected.
        noise = tio.Noise(0, {'im': 1}, {'im': 0})
        with self.assertRaises(ValueError):
            noise.arguments_are_dict()

    def test_bad_over_max(self):
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range(2, 'name', max_constraint=1)

    def test_bad_over_max_range(self):
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range((0, 2), 'name', max_constraint=1)

    def test_bad_type(self):
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range(2.5, 'name', type_constraint=int)

    def test_no_numbers(self):
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range('j', 'name')

    def test_apply_transform_missing(self):
        # A subclass that does not implement the abstract interface
        # cannot be used.
        class Incomplete(tio.Transform):
            pass

        with self.assertRaises(TypeError):
            Incomplete().apply_transform(0)

    def test_non_invertible(self):
        blur = tio.RandomBlur()
        with self.assertRaises(RuntimeError):
            blur.inverse()

    def test_bad_bounds_mask(self):
        normalization = tio.ZNormalization(masking_method='test')
        with self.assertRaises(ValueError):
            normalization(self.sample_subject)

    def test_bounds_mask(self):
        normalization = tio.ZNormalization()
        with self.assertRaises(ValueError):
            normalization.get_mask_from_anatomical_label('test', 0)
        volume = torch.rand((1, 2, 2, 2))

        # Each entry: (anatomical label, dimension of the 4D mask it splits,
        # index along that dimension that must be fully selected). The
        # opposite index (1 - index) must be empty.
        expectations = (
            ('Left', 1, 0),
            ('Right', 1, 1),
            ('Posterior', 2, 0),
            ('Anterior', 2, 1),
            ('Inferior', 3, 0),
            ('Superior', 3, 1),
        )
        for label, dim, selected in expectations:
            half = normalization.get_mask_from_anatomical_label(label, volume)
            leading = (slice(None),) * dim
            inside = half[leading + (selected,)]
            outside = half[leading + (1 - selected,)]
            assert inside.sum() == 4 and outside.sum() == 0

        bounds_mask = normalization.get_mask_from_bounds(3 * (0, 1), volume)
        assert bounds_mask[0, 0, 0, 0] == 1
        assert bounds_mask.sum() == 1
310
|
|
|
|