tests.transforms.test_transforms   C
last analyzed

Complexity

Total Complexity 54

Size/Duplication

Total Lines 304
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 54
eloc 250
dl 0
loc 304
rs 6.4799
c 0
b 0
f 0

29 Methods

Rating   Name   Duplication   Size   Complexity  
A TestTransforms.test_transforms_dict() 0 5 1
A TestTransforms.test_transforms_subject_4d() 0 30 3
A TestTransforms.test_transforms_subject_3d() 0 4 1
A TestTransforms.test_transforms_sitk() 0 8 1
A TestTransforms.test_transform_noop() 0 7 1
A TestTransform.test_non_invertible() 0 4 2
A TestTransforms.test_transforms_tensor() 0 6 1
A TestTransforms.test_keys_deprecated() 0 3 2
A TestTransforms.test_original_unchanged() 0 11 2
A TestTransforms.test_transforms_subject_2d() 0 5 1
A TestTransform.test_bad_over_max_range() 0 4 2
A TestTransforms.test_transforms_use_include() 0 15 1
A TestTransform.test_abstract_transform() 0 3 2
A TestTransform.test_bad_over_max() 0 4 2
A TestTransform.test_bad_bounds_mask() 0 4 2
A TestTransforms.test_transforms_image() 0 5 1
A TestTransform.test_bad_type() 0 4 2
A TestTransform.test_arguments_are_dict() 0 3 1
A TestTransforms.test_transforms_dict_no_keys() 0 5 2
A TestTransforms.test_keep_original() 0 13 1
A TestTransforms.test_transforms_use_include_and_exclude() 0 3 2
A TestTransform.test_no_numbers() 0 4 2
A TestTransform.test_arguments_are_and_are_not_dict() 0 4 2
A TestTransform.test_apply_transform_missing() 0 5 2
A TestTransforms.test_transforms_use_exclude() 0 15 1
A TestTransforms.test_transforms_array() 0 6 1
C TestTransforms.get_transform() 0 48 10
A TestTransform.test_arguments_are_not_dict() 0 3 1
A TestTransform.test_bounds_mask() 0 26 2

How to fix   Complexity   

Complexity

Complex classes or modules like tests.transforms.test_transforms often do many different things. To break one down, identify a cohesive component within it. A common way to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined which fields and methods belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.

1
import copy
2
import torch
3
import numpy as np
4
import torchio as tio
5
import SimpleITK as sitk
6
from ..utils import TorchioTestCase
7
8
9
class TestTransforms(TorchioTestCase):
    """Tests for all transforms."""

    def get_transform(self, channels, is_3d=True, labels=True):
        """Build a composition that chains (almost) every transform.

        Args:
            channels: Image names used as keys of the histogram
                standardization landmarks. The first one is also the
                reference image for ``CopyAffine``.
            is_3d: If ``False``, use spatial parameters sized for 2D
                subjects (target shapes whose last spatial dimension is 1).
            labels: If ``True``, append ``RandomLabelsToImage``, which reads
                the image named ``'label'``.

        Returns:
            A ``tio.Compose`` wrapping all the transforms, in order.
        """
        landmarks_dict = {
            channel: np.linspace(0, 100, 13) for channel in channels
        }
        disp = 1 if is_3d else (1, 1, 0.01)
        elastic = tio.RandomElasticDeformation(max_displacement=disp)
        cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
        resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
        flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
        swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
        pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
        remapping = {1: 2, 2: 1, 3: 20, 4: 25}
        transforms = [
            tio.CropOrPad(cp_args),
            tio.EnsureShapeMultiple(2, method='crop'),
            tio.Resize(resize_args),
            tio.ToCanonical(),
            tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
            tio.CopyAffine(channels[0]),
            tio.Resample((1, 1.1, 1.25)),
            tio.RandomFlip(axes=flip_axes, flip_probability=1),
            tio.RandomMotion(),
            tio.RandomGhosting(axes=(0, 1, 2)),
            tio.RandomSpike(),
            tio.RandomNoise(),
            tio.RandomBlur(),
            tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
            tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
            tio.RandomBiasField(),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.ZNormalization(),
            tio.HistogramStandardization(landmarks_dict),
            elastic,
            tio.RandomAffine(),
            # The same elastic instance is reused inside OneOf on purpose;
            # it only needs to be a valid choice.
            tio.OneOf({
                tio.RandomAffine(): 3,
                elastic: 1,
            }),
            tio.RemapLabels(remapping=remapping, masking_method='Left'),
            tio.RemoveLabels([1, 3]),
            tio.SequentialLabels(),
            tio.Pad(pad_args, padding_mode=3),
            tio.Crop(crop_args),
        ]
        if labels:
            transforms.append(tio.RandomLabelsToImage(label_key='label'))
        return tio.Compose(transforms)

    def _get_image_transform(self):
        """Return the full pipeline for non-subject inputs, without labels.

        NOTE(review): 'default_image_name' presumably matches the key TorchIO
        uses when it wraps a standalone image/tensor/array; confirm.
        """
        return self.get_transform(
            channels=('default_image_name',), labels=False)

    def _assert_only_t1_changed(self, transform):
        """Apply ``transform`` to the sample subject and check its effect.

        The transform must modify ``t1`` and leave ``t2`` untouched.
        """
        original_subject = copy.deepcopy(self.sample_subject)
        transformed = transform(self.sample_subject)
        # assertTensorNotEqual fails when nothing changed, hence the message.
        self.assertTensorNotEqual(
            original_subject.t1.data,
            transformed.t1.data,
            f'No changes after {transform.name}'
        )
        self.assertTensorEqual(
            original_subject.t2.data,
            transformed.t2.data,
            f'Changes after {transform.name}'
        )

    def test_transforms_dict(self):
        transform = tio.RandomNoise(include=('t1', 't2'))
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        transformed = transform(input_dict)
        self.assertIsInstance(transformed, dict)

    def test_transforms_dict_no_keys(self):
        # A plain dict input without include/exclude is rejected.
        transform = tio.RandomNoise()
        input_dict = {k: v.data for (k, v) in self.sample_subject.items()}
        with self.assertRaises(RuntimeError):
            transform(input_dict)

    def test_transforms_image(self):
        transform = self._get_image_transform()
        transformed = transform(self.sample_subject.t1)
        self.assertIsInstance(transformed, tio.ScalarImage)

    def test_transforms_tensor(self):
        tensor = torch.rand(2, 4, 5, 8)
        transform = self._get_image_transform()
        transformed = transform(tensor)
        self.assertIsInstance(transformed, torch.Tensor)

    def test_transforms_array(self):
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transform = self._get_image_transform()
        transformed = transform(tensor)
        self.assertIsInstance(transformed, np.ndarray)

    def test_transforms_sitk(self):
        tensor = torch.rand(2, 4, 5, 8)
        affine = np.diag((-1, 2, -3, 1))
        image = tio.data.io.nib_to_sitk(tensor, affine)
        transform = self._get_image_transform()
        transformed = transform(image)
        self.assertIsInstance(transformed, sitk.Image)

    def test_transforms_subject_3d(self):
        transform = self.get_transform(channels=('t1', 't2'), is_3d=True)
        transformed = transform(self.sample_subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_2d(self):
        transform = self.get_transform(channels=('t1', 't2'), is_3d=False)
        subject = self.make_2d(self.sample_subject)
        transformed = transform(subject)
        self.assertIsInstance(transformed, tio.Subject)

    def test_transforms_subject_4d(self):
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.make_multichannel(self.sample_subject)
        subject = self.flip_affine_x(subject)
        # Label-map transforms and CopyAffine are not expected to change the
        # voxels of t1, so the channel and change checks below skip them.
        exclude = (
            'RandomLabelsToImage',
            'RemapLabels',
            'RemoveLabels',
            'SequentialLabels',
            'CopyAffine',
        )
        transformed = None
        for transform in composed.transforms:
            repr(transform)  # cover __repr__
            transformed = transform(subject)
            trsf_channels = len(transformed.t1.data)
            assert trsf_channels > 1, f'Lost channels in {transform.name}'
            if transform.name not in exclude:
                self.assertEqual(
                    subject.shape[0],
                    transformed.shape[0],
                    f'Different number of channels after {transform.name}'
                )
                self.assertTensorNotEqual(
                    subject.t1.data[1],
                    transformed.t1.data[1],
                    f'No changes after {transform.name}'
                )
            # Feed each output into the next transform, like Compose does.
            subject = transformed
        self.assertIsInstance(transformed, tio.Subject)

    def test_transform_noop(self):
        # With p=0 the transform must return its input object untouched.
        transform = tio.RandomMotion(p=0)
        transformed = transform(self.sample_subject)
        self.assertIs(transformed, self.sample_subject)
        tensor = torch.rand(2, 4, 5, 8).numpy()
        transformed = transform(tensor)
        self.assertIs(transformed, tensor)

    def test_original_unchanged(self):
        subject = copy.deepcopy(self.sample_subject)
        composed = self.get_transform(channels=('t1', 't2'), is_3d=True)
        subject = self.flip_affine_x(subject)
        for transform in composed.transforms:
            original_data = copy.deepcopy(subject.t1.data)
            transform(subject)  # output discarded: only the input matters
            self.assertTensorEqual(
                subject.t1.data,
                original_data,
                f'Changes after {transform.name}'
            )

    def test_transforms_use_include(self):
        self._assert_only_t1_changed(tio.RandomNoise(include=['t1']))

    def test_transforms_use_exclude(self):
        self._assert_only_t1_changed(tio.RandomNoise(exclude=['t2']))

    def test_transforms_use_include_and_exclude(self):
        with self.assertRaises(ValueError):
            tio.RandomNoise(include=['t2'], exclude=['t1'])

    def test_keys_deprecated(self):
        with self.assertWarns(UserWarning):
            tio.RandomNoise(keys=['t2'])

    def test_keep_original(self):
        subject = copy.deepcopy(self.sample_subject)
        old, new = 't1', 't1_original'
        transformed = tio.RandomAffine(keep={old: new})(subject)
        assert old in transformed
        assert new in transformed
        # The kept copy holds the pre-transform data...
        self.assertTensorEqual(
            transformed[new].data,
            subject[old].data,
        )
        # ...while the original key holds the transformed data.
        self.assertTensorNotEqual(
            transformed[new].data,
            transformed[old].data,
        )
221
222
223
class TestTransform(TorchioTestCase):
    """Tests for the base ``Transform`` API: arguments, ranges and masks."""

    def _assert_bad_range(self, value, **constraints):
        """Check that ``_parse_range`` rejects ``value`` with ValueError."""
        noise = tio.RandomNoise()
        with self.assertRaises(ValueError):
            noise._parse_range(value, 'name', **constraints)

    def test_abstract_transform(self):
        # The abstract base class must not be instantiable.
        with self.assertRaises(TypeError):
            tio.Transform()

    def test_arguments_are_not_dict(self):
        noise = tio.Noise(0, 1, 0)
        assert not noise.arguments_are_dict()

    def test_arguments_are_dict(self):
        noise = tio.Noise({'im': 0}, {'im': 1}, {'im': 0})
        assert noise.arguments_are_dict()

    def test_arguments_are_and_are_not_dict(self):
        # Mixing plain values and dicts is ambiguous and must be rejected.
        noise = tio.Noise(0, {'im': 1}, {'im': 0})
        with self.assertRaises(ValueError):
            noise.arguments_are_dict()

    def test_bad_over_max(self):
        self._assert_bad_range(2, max_constraint=1)

    def test_bad_over_max_range(self):
        self._assert_bad_range((0, 2), max_constraint=1)

    def test_bad_type(self):
        self._assert_bad_range(2.5, type_constraint=int)

    def test_no_numbers(self):
        self._assert_bad_range('j')

    def test_apply_transform_missing(self):
        # A subclass that does not implement apply_transform must fail with
        # TypeError.
        class TransformWithoutApply(tio.Transform):
            pass

        with self.assertRaises(TypeError):
            TransformWithoutApply().apply_transform(0)

    def test_non_invertible(self):
        blur = tio.RandomBlur()
        with self.assertRaises(RuntimeError):
            blur.inverse()

    def test_bad_bounds_mask(self):
        normalization = tio.ZNormalization(masking_method='test')
        with self.assertRaises(ValueError):
            normalization(self.sample_subject)

    def test_bounds_mask(self):
        normalization = tio.ZNormalization()
        with self.assertRaises(ValueError):
            normalization.get_mask_from_anatomical_label('test', 0)
        volume = torch.rand((1, 2, 2, 2))

        # Each row: label, index prefix selecting the axis, and the index of
        # the half that must be fully masked (4 voxels). The other half must
        # be empty.
        everything = slice(None)
        cases = (
            ('Left', (everything,), 0),
            ('Right', (everything,), 1),
            ('Posterior', (everything, everything), 0),
            ('Anterior', (everything, everything), 1),
            ('Inferior', (Ellipsis,), 0),
            ('Superior', (Ellipsis,), 1),
        )
        for label, prefix, filled in cases:
            half_mask = normalization.get_mask_from_anatomical_label(
                label, volume)
            filled_sum = half_mask[prefix + (filled,)].sum()
            assert filled_sum == 4 and half_mask[prefix + (1 - filled,)].sum() == 0

        # Bounds (0, 1) on every axis select exactly the first voxel.
        bounded = normalization.get_mask_from_bounds(3 * (0, 1), volume)
        assert bounded[0, 0, 0, 0] == 1
        assert bounded.sum() == 1
304