Passed
Pull Request — master (#533)
by Fernando
01:23
created

tests.transforms.test_transforms   C

Complexity

Total Complexity 54

Size/Duplication

Total Lines 310
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 54
eloc 255
dl 0
loc 310
rs 6.4799
c 0
b 0
f 0

30 Methods

Rating   Name   Duplication   Size   Complexity  
A TestTransforms.test_transforms_dict() 0 5 1
A TestTransforms.test_transforms_subject_3d() 0 4 1
A TestTransforms.test_transforms_sitk() 0 8 1
A TestTransforms.test_transforms_tensor() 0 6 1
A TestTransforms.test_transforms_subject_2d() 0 5 1
A TestTransforms.test_transforms_image() 0 5 1
A TestTransforms.test_transforms_dict_no_keys() 0 5 2
A TestTransforms.test_transforms_array() 0 6 1
C TestTransforms.get_transform() 0 45 9
A TestTransforms.test_transforms_nib() 0 8 1
A TestTransforms.test_transforms_subject_4d() 0 29 3
A TestTransforms.test_transform_noop() 0 7 1
A TestTransform.test_non_invertible() 0 4 2
A TestTransforms.test_keys_deprecated() 0 3 2
A TestTransforms.test_original_unchanged() 0 11 2
A TestTransform.test_bad_over_max_range() 0 4 2
A TestTransforms.test_transforms_use_include() 0 15 1
A TestTransform.test_abstract_transform() 0 3 2
A TestTransform.test_bad_over_max() 0 4 2
A TestTransform.test_bad_bounds_mask() 0 4 2
A TestTransform.test_bad_type() 0 4 2
A TestTransform.test_arguments_are_dict() 0 3 1
A TestTransforms.test_keep_original() 0 13 1
A TestTransforms.test_transforms_use_include_and_exclude() 0 3 2
A TestTransform.test_no_numbers() 0 4 2
A TestTransform.test_arguments_are_and_are_not_dict() 0 4 2
A TestTransform.test_apply_transform_missing() 0 5 2
A TestTransforms.test_transforms_use_exclude() 0 15 1
A TestTransform.test_arguments_are_not_dict() 0 3 1
A TestTransform.test_bounds_mask() 0 26 2

How to fix   Complexity   

Complexity

Complex modules like tests.transforms.test_transforms (which defines the TestTransforms and TestTransform classes) often do many different things. To break such a class down, identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share the same prefix or suffix.

Once you have determined which fields and methods belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.

1
import copy
2
import torch
3
import numpy as np
4
import nibabel as nib
5
import torchio as tio
6
import SimpleITK as sitk
7
from ..utils import TorchioTestCase
8
9
10
class TestTransforms(TorchioTestCase):
    """Tests for all transforms.

    Most tests build a long ``tio.Compose`` pipeline with ``get_transform``
    and run it over a different kind of input (subject, image, tensor,
    NumPy array, SimpleITK image, NiBabel image, dict). They then check that
    the output has the same type as the input, or that the input was not
    modified in place.
    """

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a composed pipeline that exercises many transforms at once.

        Args:
            channels: Image names used as keys of the landmarks dictionary
                passed to ``tio.HistogramStandardization``.
            is_3d: If ``False``, use arguments adapted to 2D inputs, e.g. a
                ``CropOrPad`` target of size 1 along the last axis and no
                flipping/downsampling along axis 2.
            labels: If ``True``, append
                ``tio.RandomLabelsToImage(label_key='label')``.

        Returns:
            A ``tio.Compose`` wrapping the transforms in the listed order.
        """
        landmarks_dict = {
            channel: np.linspace(0, 100, 13) for channel in channels
        }
        disp = 1 if is_3d else (1, 1, 0.01)
        # The same elastic instance is used both on its own and inside OneOf.
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            tio.OneOf({
                tio.RandomAffine(): 3,
                elastic: 1,
            }),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def _get_single_image_transform(self):
        """Return the 3D pipeline for a lone, unnamed image (no label map)."""
        return self.get_transform(
            channels=('default_image_name',), labels=False)

    def _sample_subject_as_tensor_dict(self):
        """Return ``{name: image.data}`` for every item of the sample subject."""
        return {k: v.data for (k, v) in self.sample_subject.items()}

    def _assert_only_t1_changed(self, transform):
        """Apply ``transform`` and check it altered ``t1`` but not ``t2``."""
        original_subject = copy.deepcopy(self.sample_subject)
        transformed = transform(self.sample_subject)
        self.assertTensorNotEqual(
            original_subject.t1.data,
            transformed.t1.data,
            f'No changes after {transform.name}',
        )
        self.assertTensorEqual(
            original_subject.t2.data,
            transformed.t2.data,
            f'Changes after {transform.name}',
        )

    def test_transforms_dict(self):
        """A plain dict input with explicit ``include`` yields a dict."""
        transform = tio.RandomNoise(include=('t1', 't2'))
        transformed = transform(self._sample_subject_as_tensor_dict())
        self.assertIsInstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        """A plain dict input without ``include`` raises ``RuntimeError``."""
        transform = tio.RandomNoise()
        input_dict = self._sample_subject_as_tensor_dict()
        with self.assertRaises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        """A ``ScalarImage`` input yields a ``ScalarImage``."""
        transform = self._get_single_image_transform()
        transformed = transform(self.sample_subject.t1)
        self.assertIsInstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        """A 4D ``torch.Tensor`` input yields a ``torch.Tensor``."""
        tensor = torch.rand(2, 4, 5, 8)
        transform = self._get_single_image_transform()
        transformed = transform(tensor)
        self.assertIsInstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        """A 4D NumPy array input yields a NumPy array."""
        array = torch.rand(2, 4, 5, 8).numpy()
        transform = self._get_single_image_transform()
        transformed = transform(array)
        self.assertIsInstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        """A SimpleITK image input yields a SimpleITK image."""
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self._get_single_image_transform()
        transformed = transform(image)
        self.assertIsInstance(transformed, sitk.Image)

    def test_transforms_nib(self):
        """A NiBabel ``Nifti1Image`` input yields a ``Nifti1Image``."""
        data = torch.rand(1, 4, 5, 8).numpy()
        affine = np.diag((1, -2, 3, 1))
        image = nib.Nifti1Image(data, affine)
        transform = self._get_single_image_transform()
        transformed = transform(image)
        self.assertIsInstance(transformed, nib.Nifti1Image)

    def test_transforms_subject_3d(self):
        """The full 3D pipeline maps a ``Subject`` to a ``Subject``."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        """The 2D pipeline maps a 2D ``Subject`` to a ``Subject``."""
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        """Each transform keeps, and (mostly) alters, every image channel."""
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        # Label-oriented transforms are exempt from the channel-count and
        # "channel 1 of t1 changed" checks below (presumably because they
        # need not touch t1's intensities).
        label_transforms = (
            'RandomLabelsToImage',
            'RemapLabels',
            'RemoveLabels',
            'SequentialLabels',
        )
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            self.assertGreater(
                len(transformed.t1.data),
                1,
                f'Lost channels in {transform.name}',
            )
            if transform.name not in label_transforms:
                self.assertEqual(
                    subject.shape[0],
                    transformed.shape[0],
                    f'Different number of channels after {transform.name}'
                )
                self.assertTensorNotEqual(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    f'No changes after {transform.name}'
                )
            subject = transformed
        self.assertIsInstance(transformed, tio.Subject)

    def test_transform_noop(self):
        """With ``p=0`` the very same input object is returned."""
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        self.assertIs(transformed, self.sample_subject)
        array = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(array)
        self.assertIs(transformed, array)

    def test_original_unchanged(self):
        """No transform in the pipeline modifies its input in place."""
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)
            self.assertTensorEqual(
                subject.t1.data,
                original_data,
                f'Changes after {transform.name}'
            )

    def test_transforms_use_include(self):
        """``include=['t1']`` restricts the transform to ``t1``."""
        self._assert_only_t1_changed(tio.RandomNoise(include=['t1']))

    def test_transforms_use_exclude(self):
        """``exclude=['t2']`` leaves ``t2`` untouched."""
        self._assert_only_t1_changed(tio.RandomNoise(exclude=['t2']))

    def test_transforms_use_include_and_exclude(self):
        """Passing both ``include`` and ``exclude`` raises ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        """The ``keys`` argument emits a ``UserWarning``."""
        with self.assertWarns(UserWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        """``keep={old: new}`` stores an untouched copy under ``new``."""
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        self.assertIn(old, transformed)
        self.assertIn(new, transformed)
        self.assertTensorEqual(
            transformed[new].data,
            subject[old].data,
        )
        self.assertTensorNotEqual(
            transformed[new].data,
            transformed[old].data,
        )
228
229
class TestTransform(TorchioTestCase):
    """Tests for the ``tio.Transform`` base class and its shared helpers."""

    def _assert_bad_range(self, value, **constraints):
        """Check that ``_parse_range`` rejects ``value`` with ``ValueError``."""
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range(value, 'name', **constraints)

    def test_abstract_transform(self):
        """The abstract base class cannot be instantiated."""
        with self.assertRaises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        """Scalar-only arguments are not reported as dicts."""
        noise = tio.Noise(0, 1, 0)
        assert not noise.arguments_are_dict()

    def test_arguments_are_dict(self):
        """Dict-only arguments are reported as dicts."""
        noise = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        assert noise.arguments_are_dict()

    def test_arguments_are_and_are_not_dict(self):
        """Mixing scalar and dict arguments raises ``ValueError``."""
        noise = tio.Noise(0, {'im': 1}, {'im': 0})
        with self.assertRaises(ValueError):
            noise.arguments_are_dict()

    def test_bad_over_max(self):
        """A scalar above ``max_constraint`` is rejected."""
        self._assert_bad_range(2, max_constraint=1)

    def test_bad_over_max_range(self):
        """A range whose upper bound exceeds ``max_constraint`` is rejected."""
        self._assert_bad_range((0, 2), max_constraint=1)

    def test_bad_type(self):
        """A float is rejected when ``type_constraint=int``."""
        self._assert_bad_range(2.5, type_constraint=int)

    def test_no_numbers(self):
        """A non-numeric value is rejected."""
        self._assert_bad_range('j')

    def test_apply_transform_missing(self):
        """A subclass lacking ``apply_transform`` raises ``TypeError``."""
        class _Incomplete(tio.Transform):
            pass
        with self.assertRaises(TypeError):
            _Incomplete().apply_transform(0)

    def test_non_invertible(self):
        """Inverting a non-invertible transform raises ``RuntimeError``."""
        blur = tio.RandomBlur()
        with self.assertRaises(RuntimeError):
            blur.inverse()

    def test_bad_bounds_mask(self):
        """An unknown ``masking_method`` string fails when applied."""
        normalize = tio.ZNormalization(masking_method='test')
        with self.assertRaises(ValueError):
            normalize(self.sample_subject)

    def test_bounds_mask(self):
        """Anatomical-label masks and bounds masks select the expected voxels."""
        normalize = tio.ZNormalization()
        with self.assertRaises(ValueError):
            normalize.get_mask_from_anatomical_label('test', 0)
        tensor = torch.rand((1, 2, 2, 2))

        # (label, index prefix selecting the checked axis, index of the half
        # that must be fully selected); the other half must be empty.
        along_dim1 = (slice(None),)
        along_dim2 = (slice(None), slice(None))
        along_last = (Ellipsis,)
        expectations = (
            ('Left', along_dim1, 0),
            ('Right', along_dim1, 1),
            ('Posterior', along_dim2, 0),
            ('Anterior', along_dim2, 1),
            ('Inferior', along_last, 0),
            ('Superior', along_last, 1),
        )
        for label, prefix, selected in expectations:
            mask = normalize.get_mask_from_anatomical_label(label, tensor)
            assert mask[prefix + (selected,)].sum() == 4
            assert mask[prefix + (1 - selected,)].sum() == 0

        bounds_mask = normalize.get_mask_from_bounds(3 * (0, 1), tensor)
        assert bounds_mask[0, 0, 0, 0] == 1
        assert bounds_mask.sum() == 1
310