tests.transforms.test_transforms   F
last analyzed

Complexity

Total Complexity 63

Size/Duplication

Total Lines 409
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 325
dl 0
loc 409
rs 3.36
c 0
b 0
f 0
wmc 63

35 Methods

Rating   Name   Duplication   Size   Complexity  
A TestTransforms.test_transforms_dict() 0 5 1
A TestTransforms.test_transforms_subject_4d() 0 28 3
A TestTransforms.test_transforms_subject_3d() 0 4 1
A TestTransforms.test_transforms_sitk() 0 10 1
A TestTransforms.test_transform_noop() 0 7 1
A TestTransform.test_non_invertible() 0 4 2
A TestTransforms.test_transforms_tensor() 0 8 1
A TestTransforms.test_keys_deprecated() 0 3 2
A TestTransforms.test_original_unchanged() 0 11 2
A TestTransforms.test_transforms_subject_2d() 0 5 1
A TestTransform.test_bad_over_max_range() 0 4 2
A TestTransforms.test_transforms_use_include() 0 15 1
A TestTransform.test_init_args() 0 20 1
A TestTransform.test_abstract_transform() 0 3 2
A TestTransform.test_batch_history() 0 22 2
A TestTransform.test_bad_over_max() 0 4 2
A TestTransform.test_bad_bounds_mask() 0 4 2
A TestTransforms.test_transforms_image() 0 7 1
A TestTransform.test_bad_type() 0 4 2
A TestTransform.test_nibabel_input() 0 13 1
A TestTransform.test_arguments_are_dict() 0 3 1
A TestTransforms.test_transforms_dict_no_keys() 0 5 2
A TestTransforms.test_keep_original() 0 13 1
A TestTransforms.test_transforms_use_include_and_exclude() 0 3 2
A TestTransform.test_label_keys() 0 21 1
A TestTransform.test_arguments_are_and_are_not_dict() 0 4 2
A TestTransform.test_no_numbers() 0 4 2
A TestTransform.test_apply_transform_missing() 0 6 2
A TestTransforms.test_transforms_use_exclude() 0 15 1
A TestTransforms.test_transforms_array() 0 8 1
C TestTransforms.get_transform() 0 52 10
A TestTransform.test_arguments_are_not_dict() 0 3 1
A TestTransform.test_bad_shape() 0 4 2
A TestTransform.test_bounds_mask() 0 26 2
A TestTransform.test_bad_keys_type() 0 4 2

How to fix   Complexity   

Complexity

Complex classes and modules like tests.transforms.test_transforms often take on too many responsibilities. To break such a module down, identify a cohesive component within it. A common approach is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.

1
import copy
2
3
import numpy as np
4
import pytest
5
import SimpleITK as sitk
6
import torch
7
from nibabel.nifti1 import Nifti1Image
8
9
import torchio as tio
10
11
from ..utils import TorchioTestCase
12
13
14
class TestTransforms(TorchioTestCase):
    """Tests for all transforms."""

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a ``Compose`` that exercises a broad set of transforms.

        Args:
            channels: Names of the images in the subject. They are used as
                keys of the dummy histogram-standardization landmarks, and
                the first one is the reference image for ``CopyAffine``.
            is_3d: If ``False``, use shape-related arguments suited to 2D
                subjects (e.g. target shapes ending in a singleton dimension).
            labels: If ``True``, append ``RandomLabelsToImage`` reading the
                ``'label'`` image.

        Returns:
            A ``tio.Compose`` with all the transforms, in application order.
        """
        landmarks_dict = {channel: np.linspace(0, 100, 13) for channel in channels}
        # In 2D, keep the elastic displacement along the last axis close to 0
        disp = 1 if is_3d else (1, 1, 0.01)
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        affine_elastic = tio.RandomAffineElasticDeformation(
            elastic_kwargs={'max_displacement': disp}
        )
        # 2D variants of the shape arguments (target shapes end in 1)
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resize(resize_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.CopyAffine(channels[0]),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            affine_elastic,
            tio.OneOf(
                {
                    tio.RandomAffine(): 3,
                    elastic: 1,
                }
            ),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def test_transforms_dict(self):
        """A dict of tensors with explicit ``include`` comes back as a dict."""
        transform = tio.RandomNoise(include=('t1', 't2'))
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        transformed = transform(input_dict)
        assert isinstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        """Transforming a dict without ``include`` raises ``RuntimeError``."""
        transform = tio.RandomNoise()
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        with pytest.raises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        """A ``ScalarImage`` input yields a ``ScalarImage`` output."""
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(self.sample_subject.t1)
        assert isinstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        """A ``torch.Tensor`` input yields a ``torch.Tensor`` output."""
        tensor = torch.rand(2, 4, 5, 8)
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(tensor)
        assert isinstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        """A NumPy array input yields a NumPy array output."""
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(tensor)
        assert isinstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        """A SimpleITK image input yields a SimpleITK image output."""
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self.get_transform(
            channels=('default_image_name',),
            labels=False,
        )
        transformed = transform(image)
        assert isinstance(transformed, sitk.Image)

    def test_transforms_subject_3d(self):
        """The full pipeline runs on a 3D subject and returns a ``Subject``."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        assert isinstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        """The 2D variant of the pipeline runs on a 2D subject."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        assert isinstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        """Each transform preserves, and (mostly) modifies, every channel."""
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            trsf_channels = len(transformed.t1.data)
            assert trsf_channels > 1, f'Lost channels in {transform.name}'
            # These transforms are skipped by the "channel count unchanged"
            # and "second channel of t1 modified" checks below
            exclude = (
                'RandomLabelsToImage',
                'RemapLabels',
                'RemoveLabels',
                'SequentialLabels',
                'CopyAffine',
            )
            if transform.name not in exclude:
                assert subject.shape[0] == transformed.shape[0], (
                    f'Different number of channels after {transform.name}'
                )
                self.assert_tensor_not_equal(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    msg=f'No changes after {transform.name}',
                )
            subject = transformed
        assert isinstance(transformed, tio.Subject)

    def test_transform_noop(self):
        """With ``p=0`` the very same input object is returned."""
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        assert transformed is self.sample_subject
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(tensor)
        assert transformed is tensor

    def test_original_unchanged(self):
        """Applying a transform does not modify the input subject's data."""
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)
            self.assert_tensor_equal(
                subject.t1.data,
                original_data,
                msg=f'Changes after {transform.name}',
            )

    def test_transforms_use_include(self):
        """Only the images listed in ``include`` are transformed."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(include=['t1'])
        transformed = transform(self.sample_subject)

        # Failure here means t1 was NOT changed, hence "No changes"
        self.assert_tensor_not_equal(
            original_subject.t1.data,
            transformed.t1.data,
            msg=f'No changes after {transform.name}',
        )

        self.assert_tensor_equal(
            original_subject.t2.data,
            transformed.t2.data,
            msg=f'Changes after {transform.name}',
        )

    def test_transforms_use_exclude(self):
        """Images listed in ``exclude`` are left untouched."""
        original_subject = copy.deepcopy(self.sample_subject)
        transform = tio.RandomNoise(exclude=['t2'])
        transformed = transform(self.sample_subject)

        # Failure here means t1 was NOT changed, hence "No changes"
        self.assert_tensor_not_equal(
            original_subject.t1.data,
            transformed.t1.data,
            msg=f'No changes after {transform.name}',
        )

        self.assert_tensor_equal(
            original_subject.t2.data,
            transformed.t2.data,
            msg=f'Changes after {transform.name}',
        )

    def test_transforms_use_include_and_exclude(self):
        """Passing both ``include`` and ``exclude`` raises ``ValueError``."""
        with pytest.raises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        """The ``keys`` argument emits a ``FutureWarning``."""
        with pytest.warns(FutureWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        """``keep`` stores an untransformed copy under the new name."""
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        assert old in transformed
        assert new in transformed
        self.assert_tensor_equal(
            transformed[new].data,
            subject[old].data,
        )
        self.assert_tensor_not_equal(
            transformed[new].data,
            transformed[old].data,
        )
236
237
238
class TestTransform(TorchioTestCase):
    """Tests for the base ``Transform`` API: argument parsing, masks,
    history/inversion and supported input types."""

    def test_abstract_transform(self):
        """The abstract base class cannot be instantiated."""
        with pytest.raises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        """Scalar arguments are not reported as per-image dicts."""
        transform = tio.Noise(0, 1, 0)
        assert not transform.arguments_are_dict()

    def test_arguments_are_dict(self):
        """Dict arguments are reported as per-image dicts."""
        transform = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        assert transform.arguments_are_dict()

    def test_arguments_are_and_are_not_dict(self):
        """Mixing scalar and dict arguments raises ``ValueError``."""
        transform = tio.Noise(0, {'im': 1}, {'im': 0})
        with pytest.raises(ValueError):
            transform.arguments_are_dict()

    def test_bad_over_max(self):
        """A scalar above ``max_constraint`` is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range(2, 'name', max_constraint=1)

    def test_bad_over_max_range(self):
        """A range whose upper bound exceeds ``max_constraint`` is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range((0, 2), 'name', max_constraint=1)

    def test_bad_type(self):
        """A float is rejected when ``type_constraint`` is ``int``."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range(2.5, 'name', type_constraint=int)

    def test_no_numbers(self):
        """A non-numeric range is rejected."""
        transform = tio.RandomNoise()
        with pytest.raises(ValueError):
            transform._parse_range('j', 'name')

    def test_apply_transform_missing(self):
        """A subclass without ``apply_transform`` cannot be used."""

        class T(tio.Transform):
            pass

        with pytest.raises(TypeError):
            T().apply_transform(0)

    def test_non_invertible(self):
        """Inverting a non-invertible transform raises ``RuntimeError``."""
        transform = tio.RandomBlur()
        with pytest.raises(RuntimeError):
            transform.inverse()

    def test_batch_history(self):
        """History collated in a batch can be used to invert each subject."""
        # https://github.com/TorchIO-project/torchio/discussions/743
        subject = self.sample_subject
        transform = tio.Compose(
            [
                tio.RandomAffine(),
                tio.CropOrPad(5),
                tio.OneHot(),
            ]
        )
        dataset = tio.SubjectsDataset([subject], transform=transform)
        loader = tio.SubjectsLoader(
            dataset,
            collate_fn=tio.utils.history_collate,
        )
        batch = tio.utils.get_first_item(loader)
        transformed: tio.Subject = tio.utils.get_subjects_from_batch(batch)[0]
        inverse = transformed.apply_inverse_transform()
        images1 = subject.get_images(intensity_only=False)
        images2 = inverse.get_images(intensity_only=False)
        # zip() silently truncates, so make sure no image went missing
        assert len(images1) == len(images2), (
            f'Expected {len(images1)} images after inversion, '
            f'got {len(images2)}'
        )
        for image1, image2 in zip(images1, images2):
            assert image1.shape == image2.shape

    def test_bad_bounds_mask(self):
        """An unknown ``masking_method`` string fails when applied."""
        transform = tio.ZNormalization(masking_method='test')
        with pytest.raises(ValueError):
            transform(self.sample_subject)

    def test_bounds_mask(self):
        """Masks from anatomical labels and from bounds select the
        expected voxels of a 2x2x2 volume."""
        transform = tio.ZNormalization()
        with pytest.raises(ValueError):
            transform.get_mask_from_anatomical_label('test', 0)
        tensor = torch.rand((1, 2, 2, 2))

        def get_mask(label):
            mask = transform.get_mask_from_anatomical_label(label, tensor)
            return mask

        # Each label selects one half (4 voxels) along one spatial axis
        left = get_mask('Left')
        assert left[:, 0].sum() == 4 and left[:, 1].sum() == 0
        right = get_mask('Right')
        assert right[:, 1].sum() == 4 and right[:, 0].sum() == 0
        posterior = get_mask('Posterior')
        assert posterior[:, :, 0].sum() == 4 and posterior[:, :, 1].sum() == 0
        anterior = get_mask('Anterior')
        assert anterior[:, :, 1].sum() == 4 and anterior[:, :, 0].sum() == 0
        inferior = get_mask('Inferior')
        assert inferior[..., 0].sum() == 4 and inferior[..., 1].sum() == 0
        superior = get_mask('Superior')
        assert superior[..., 1].sum() == 4 and superior[..., 0].sum() == 0

        # Bounds (0, 1) on every axis keep only the first voxel
        mask = transform.get_mask_from_bounds(3 * (0, 1), tensor)
        assert mask[0, 0, 0, 0] == 1
        assert mask.sum() == 1

    def test_label_keys(self):
        """Images named in ``label_keys`` are treated as label maps."""
        # Adapted from the issue in which the feature was requested:
        # https://github.com/TorchIO-project/torchio/issues/866#issue-1222255576
        size = 1, 10, 10, 10
        image = torch.rand(size)
        num_classes = 2  # excluding background
        label = torch.randint(num_classes + 1, size)

        data_dict = {'image': image, 'label': label}

        transform = tio.RandomAffine(
            include=['image', 'label'],
            label_keys=['label'],
        )
        transformed_label = transform(data_dict)['label']

        # If the image is indeed transformed as a label map, nearest neighbor
        # interpolation is used by default and therefore no intermediate values
        # can exist in the output
        num_unique_values = len(torch.unique(transformed_label))
        assert num_unique_values <= num_classes + 1

    def test_nibabel_input(self):
        """NiBabel images (3D and multi-dimensional arrays) are accepted
        and the result exposes ``get_fdata`` and ``affine``."""
        image = self.sample_subject.t1
        image_nib = Nifti1Image(image.data[0].numpy(), image.affine)
        transformed = tio.RandomAffine()(image_nib)
        transformed.get_fdata()
        _ = transformed.affine

        image = self.subject_4d.t1
        # Add a leading axis, then move the last three axes to the front;
        # assumes the image data is channels-first (C, X, Y, Z) — confirm
        tensor_5d = image.data[np.newaxis].permute(2, 3, 4, 0, 1)
        image_nib = Nifti1Image(tensor_5d.numpy(), image.affine)
        transformed = tio.RandomAffine()(image_nib)
        transformed.get_fdata()
        _ = transformed.affine

    def test_bad_shape(self):
        """A tensor that is not 4D is rejected with a clear message."""
        tensor = torch.rand(1, 2, 3)
        with pytest.raises(ValueError, match='must be a 4D tensor'):
            tio.RandomAffine()(tensor)

    def test_bad_keys_type(self):
        """``include`` must be a sequence of names, not a single string."""
        # From https://github.com/TorchIO-project/torchio/issues/923
        with pytest.raises(ValueError):
            tio.RandomAffine(include='t1')

    def test_init_args(self):
        """``get_base_args`` exposes the expected constructor arguments."""
        # Composite transforms must not report 'parse_input'
        transform = tio.Compose([tio.RandomNoise()])
        base_args = transform.get_base_args()
        assert 'parse_input' not in base_args

        transform = tio.OneOf([tio.RandomNoise()])
        base_args = transform.get_base_args()
        assert 'parse_input' not in base_args

        transform = tio.RandomNoise()
        base_args = transform.get_base_args()
        expected_args = (
            'copy',
            'include',
            'exclude',
            'keep',
            'parse_input',
            'label_keys',
        )
        # Report exactly which arguments are missing instead of a bare False
        missing = [arg for arg in expected_args if arg not in base_args]
        assert not missing, f'Missing base args: {missing}'
411