|
1
|
|
|
import copy |
|
2
|
|
|
import torch |
|
3
|
|
|
import numpy as np |
|
4
|
|
|
import nibabel as nib |
|
5
|
|
|
import torchio as tio |
|
6
|
|
|
import SimpleITK as sitk |
|
7
|
|
|
from ..utils import TorchioTestCase |
|
8
|
|
|
|
|
9
|
|
|
|
|
10
|
|
|
class TestTransforms(TorchioTestCase):
    """Tests for all transforms."""

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a ``tio.Compose`` chaining a broad set of transforms.

        Args:
            channels: Image names used as keys of the landmarks dictionary
                passed to ``tio.HistogramStandardization``.
            is_3d: If ``False``, use the alternative argument values meant
                for 2D subjects (e.g. ``(21, 30, 1)`` instead of
                ``(9, 21, 30)`` for ``tio.CropOrPad``).
            labels: If ``True``, append ``tio.RandomLabelsToImage`` reading
                the ``'label'`` image.

        Returns:
            A ``tio.Compose`` wrapping the transforms in the listed order.
        """
        landmarks_dict = {
            channel: np.linspace(0, 100, 13) for channel in channels
        }
        disp = 1 if is_3d else (1, 1, 0.01)
        # The same elastic instance is used on its own and inside OneOf.
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            tio.OneOf({
                tio.RandomAffine(): 3,
                elastic: 1,
            }),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def test_transforms_dict(self):
        """A plain dict of tensors goes in, a dict comes out."""
        transform = tio.RandomNoise(include=('t1', 't2'))
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        transformed = transform(input_dict)
        self.assertIsInstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        """A plain dict input without ``include`` raises RuntimeError."""
        transform = tio.RandomNoise()
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        with self.assertRaises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        """A ``ScalarImage`` input yields a ``ScalarImage`` output."""
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(self.sample_subject.t1)
        self.assertIsInstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        """A ``torch.Tensor`` input yields a ``torch.Tensor`` output."""
        tensor = torch.rand(2, 4, 5, 8)
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(tensor)
        self.assertIsInstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        """A NumPy array input yields a NumPy array output."""
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(tensor)
        self.assertIsInstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        """A SimpleITK image input yields a SimpleITK image output."""
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(image)
        self.assertIsInstance(transformed, sitk.Image)

    def test_transforms_nib(self):
        """A NiBabel ``Nifti1Image`` input yields a ``Nifti1Image`` output."""
        data = torch.rand(1, 4, 5, 8).numpy()
        affine = np.diag((1, -2, 3, 1))
        image = nib.Nifti1Image(data, affine)
        transform = self.get_transform(
            channels=('default_image_name',), labels=False)
        transformed = transform(image)
        self.assertIsInstance(transformed, nib.Nifti1Image)

    def test_transforms_subject_3d(self):
        """The full pipeline runs on a 3D subject and returns a Subject."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        """The full pipeline runs on a 2D subject and returns a Subject."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        """Apply each transform in turn to a multichannel subject.

        After every step ``t1`` must still have more than one channel. For
        transforms not listed in ``exclude``, the channel count must be
        preserved and channel 1 of ``t1`` must change.
        """
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        # Label transforms are exempt from the per-step checks below;
        # presumably because they leave t1 unchanged (TODO confirm).
        exclude = (
            'RandomLabelsToImage',
            'RemapLabels',
            'RemoveLabels',
            'SequentialLabels',
        )
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            trsf_channels = len(transformed.t1.data)
            self.assertGreater(
                trsf_channels, 1, f'Lost channels in {transform.name}')
            if transform.name not in exclude:
                self.assertEqual(
                    subject.shape[0],
                    transformed.shape[0],
                    f'Different number of channels after {transform.name}',
                )
                self.assertTensorNotEqual(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    f'No changes after {transform.name}',
                )
            # Chain the transforms: each step consumes the previous output.
            subject = transformed
        self.assertIsInstance(transformed, tio.Subject)

    def test_transform_noop(self):
        """With ``p=0`` the very same input object is returned."""
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        self.assertIs(transformed, self.sample_subject)
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(tensor)
        self.assertIs(transformed, tensor)

    def test_original_unchanged(self):
        """No transform may modify the ``t1`` data of its input in place."""
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)
            self.assertTensorEqual(
                subject.t1.data,
                original_data,
                f'Changes after {transform.name}',
            )

    def test_transforms_use_include(self):
        """``include=['t1']`` changes ``t1`` and leaves ``t2`` untouched."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(include=['t1'])
        transformed = transform(self.sample_subject)

        # This assertion fails only when t1 was NOT modified.
        self.assertTensorNotEqual(
            original_subject.t1.data,
            transformed.t1.data,
            f'No changes after {transform.name}',
        )

        self.assertTensorEqual(
            original_subject.t2.data,
            transformed.t2.data,
            f'Changes after {transform.name}',
        )

    def test_transforms_use_exclude(self):
        """``exclude=['t2']`` changes ``t1`` and leaves ``t2`` untouched."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(exclude=['t2'])
        transformed = transform(self.sample_subject)

        # This assertion fails only when t1 was NOT modified.
        self.assertTensorNotEqual(
            original_subject.t1.data,
            transformed.t1.data,
            f'No changes after {transform.name}',
        )

        self.assertTensorEqual(
            original_subject.t2.data,
            transformed.t2.data,
            f'Changes after {transform.name}',
        )

    def test_transforms_use_include_and_exclude(self):
        """Passing both ``include`` and ``exclude`` raises ValueError."""
        with self.assertRaises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        """Passing ``keys`` emits a UserWarning."""
        with self.assertWarns(UserWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        """``keep={old: new}`` stores an untransformed copy under ``new``."""
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        self.assertIn(old, transformed)
        self.assertIn(new, transformed)
        # The kept copy matches the input image...
        self.assertTensorEqual(
            transformed[new].data,
            subject[old].data,
        )
        # ...while the image under the original name was transformed.
        self.assertTensorNotEqual(
            transformed[new].data,
            transformed[old].data,
        )


class TestTransform(TorchioTestCase):
    """Tests for the base ``tio.Transform`` API and its helper methods."""

    def test_abstract_transform(self):
        # Instantiating the base class directly must raise TypeError.
        with self.assertRaises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        noise = tio.Noise(0, 1, 0)
        assert not noise.arguments_are_dict()

    def test_arguments_are_dict(self):
        noise = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        assert noise.arguments_are_dict()

    def test_arguments_are_and_are_not_dict(self):
        # A mix of scalar and dict arguments is rejected with ValueError.
        noise = tio.Noise(0, {'im': 1}, {'im': 0})
        with self.assertRaises(ValueError):
            noise.arguments_are_dict()

    def test_bad_over_max(self):
        parser = tio.RandomNoise()
        with self.assertRaises(ValueError):
            parser._parse_range(2, 'name', max_constraint=1)

    def test_bad_over_max_range(self):
        parser = tio.RandomNoise()
        with self.assertRaises(ValueError):
            parser._parse_range((0, 2), 'name', max_constraint=1)

    def test_bad_type(self):
        parser = tio.RandomNoise()
        with self.assertRaises(ValueError):
            parser._parse_range(2.5, 'name', type_constraint=int)

    def test_no_numbers(self):
        parser = tio.RandomNoise()
        with self.assertRaises(ValueError):
            parser._parse_range('j', 'name')

    def test_apply_transform_missing(self):
        class NoApply(tio.Transform):
            """Subclass of ``tio.Transform`` that defines nothing itself."""

        with self.assertRaises(TypeError):
            NoApply().apply_transform(0)

    def test_non_invertible(self):
        blur = tio.RandomBlur()
        with self.assertRaises(RuntimeError):
            blur.inverse()

    def test_bad_bounds_mask(self):
        normalization = tio.ZNormalization(masking_method='test')
        with self.assertRaises(ValueError):
            normalization(self.sample_subject)

    def test_bounds_mask(self):
        """Check anatomical-label masks and bounds masks on a 2x2x2 volume."""
        normalization = tio.ZNormalization()
        with self.assertRaises(ValueError):
            normalization.get_mask_from_anatomical_label('test', 0)
        volume = torch.rand((1, 2, 2, 2))

        every = slice(None)
        # Each entry: (label, index of the half that must be fully masked,
        # index of the half that must stay empty). A half of the
        # (1, 2, 2, 2) volume holds 4 voxels.
        expectations = (
            ('Left', (every, 0), (every, 1)),
            ('Right', (every, 1), (every, 0)),
            ('Posterior', (every, every, 0), (every, every, 1)),
            ('Anterior', (every, every, 1), (every, every, 0)),
            ('Inferior', (Ellipsis, 0), (Ellipsis, 1)),
            ('Superior', (Ellipsis, 1), (Ellipsis, 0)),
        )
        for label, filled, empty in expectations:
            half_mask = normalization.get_mask_from_anatomical_label(
                label, volume)
            assert half_mask[filled].sum() == 4 and half_mask[empty].sum() == 0

        # Bounds of (0, 1) on all three axes leave only the first voxel.
        bounds_mask = normalization.get_mask_from_bounds(
            (0, 1, 0, 1, 0, 1), volume)
        assert bounds_mask[0, 0, 0, 0] == 1
        assert bounds_mask.sum() == 1