tests.utils.TorchioTestCase.get_image_path()   D
last analyzed

Complexity

Conditions 13

Size

Total Lines 46
Code Lines 44

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 13
eloc 44
nop 10
dl 0
loc 46
rs 4.2
c 0
b 0
f 0

How to fix: Complexity, Many Parameters

Complexity

Complex classes and methods, like tests.utils.TorchioTestCase.get_image_path(), often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

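Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster. For a single long method such as the one flagged here, the analogous step is Extract Method: move each cohesive group of statements into its own helper.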

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

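There are several approaches to avoid long parameter lists:

- Replace Parameter with Method Call: let the method obtain a value itself instead of receiving it as an argument.
- Introduce Parameter Object: group parameters that always travel together into a single object.
- Preserve Whole Object: pass the object that owns several of the values instead of passing the values individually.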

1
import os
2
import copy
3
import shutil
4
import random
5
import tempfile
6
import unittest
7
from pathlib import Path
8
from random import shuffle
9
10
import torch
11
import numpy as np
12
from numpy.testing import assert_array_equal, assert_array_almost_equal
13
import torchio as tio
14
15
16
class TorchioTestCase(unittest.TestCase):
    """Base test case providing synthetic subjects, images and helpers."""

    def setUp(self):
        """Set up test fixtures, if any.

        Creates a uniquely named directory under the system temp dir, seeds
        Python's and NumPy's global RNGs, and builds ``self.subjects_list``,
        ``self.dataset`` and ``self.sample_subject`` from random images.
        """
        self.dir = Path(tempfile.gettempdir()) / os.urandom(24).hex()
        self.dir.mkdir(exist_ok=True)
        random.seed(42)
        np.random.seed(42)

        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d

    def make_2d(self, subject):
        """Return a deep copy of ``subject`` keeping only index 0 of the last axis."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(image.data[..., :1])
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of ``subject`` with each image's data tiled 4x.

        The data is concatenated with itself along the first dimension.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(torch.cat(4 * (image.data,)))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of ``subject`` with the first affine row negated.

        Each affine is left-multiplied by ``diag(-1, 1, 1, 1)``.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path"""
        path = self.get_image_path(
            'ref',
            shape=(10, 20, 31),
            spacing=(1, 1, 2),
        )
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject with a partial-volume label map.

        The label map is generated with ``binary=False`` and no ``labels``,
        so it holds random (possibly integer-cast) values.
        """
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def get_subject_with_labels(self, labels):
        """Return a subject whose label map contains the values in ``labels``."""
        return tio.Subject(
            label=tio.LabelMap(
                self.get_image_path(
                    'label_multi', labels=labels
                )
            )
        )

    def get_unique_labels(self, label_map):
        """Return the set of nonzero values in ``label_map.data`` as scalars."""
        labels = torch.unique(label_map.data)
        labels = {i.item() for i in labels if i != 0}
        return labels

    def tearDown(self):
        """Tear down test fixtures, if any."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return ``tio.datasets.IXITiny`` rooted at ``<tmp>/torchio/ixi_tiny``.

        The dataset is constructed with ``download=True``.
        """
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def get_image_path(
            self,
            stem,
            binary=False,
            labels=None,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
            ):
        """Save a random synthetic image in ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: If ``True``, voxels are 0 or 1 (``uint8``).
            labels: Used only when ``binary`` is ``False``; voxels take the
                values in ``labels`` or 0. Values must fit in ``uint8``.
            shape: Spatial shape; a 2D shape gets a trailing axis of size 1.
            spacing: Diagonal entries of the affine (voxel spacing).
            components: Size of the first (channel) axis of the data.
            add_nans: If ``True``, every voxel is NaN (float64) and NaN
                checking is disabled when building the image.
            suffix: File extension; chosen at random if ``None``.
            force_binary_foreground: If ``True`` and a binary image came out
                empty, index 0 of the last axis is set to 1.

        Returns:
            The saved path, randomly either a ``Path`` or a ``str``.
        """
        data = self._make_random_data(
            shape=shape,
            components=components,
            binary=binary,
            labels=labels,
            add_nans=add_nans,
            force_binary_foreground=force_binary_foreground,
        )
        affine = np.diag((*spacing, 1))
        if suffix is None:
            extensions = '.nii.gz', '.nii', '.nrrd', '.img', '.mnc'
            suffix = random.choice(extensions)
        path = self.dir / f'{stem}{suffix}'
        if self.flip_coin():  # exercise both str and Path inputs
            path = str(path)
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def _make_random_data(
            self,
            shape,
            components,
            binary,
            labels,
            add_nans,
            force_binary_foreground,
            ):
        """Return the random ``(components, *shape)`` array for ``get_image_path``."""
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        elif labels is not None:
            data = self._make_label_data(data, labels)
        elif self.flip_coin():  # cast some images
            data *= 100
            dtype = np.uint8 if self.flip_coin() else np.uint16
            data = data.astype(dtype)
        if add_nans:
            # NaN only exists in floating point. Assigning it in place into
            # the integer arrays produced above raised ValueError, so build
            # a fresh float64 array instead. This is identical to the old
            # result in the float case and draws no extra random numbers.
            data = np.full(data.shape, np.nan)
        return data

    @staticmethod
    def _make_label_data(uniform, labels):
        """Map samples in [0, 1) to the values in ``labels`` (or background 0)."""
        # Bin 0 is background; bin i + 1 becomes labels[i]
        indices = (uniform * (len(labels) + 1)).astype(np.uint8)
        label_map = np.zeros_like(indices)  # uint8, so labels must fit in it
        for i, label in enumerate(labels):
            label_map[indices == (i + 1)] = label
            if not (label_map == label).sum():
                # Guarantee every label appears at least once.
                # NOTE(review): assumes len(labels) <= size of the last axis
                label_map[..., i] = label
        return label_map

    def flip_coin(self):
        """Return ``True`` with probability 0.5, using NumPy's global RNG."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assertTensorNotEqual(self, *args, **kwargs):  # noqa: N802
        """Assert that ``assertTensorEqual(*args, **kwargs)`` fails.

        With exactly three positional arguments, the third is also passed as
        ``msg`` to ``assertRaises``.
        """
        message_kwarg = {'msg': args[2]} if len(args) == 3 else {}
        with self.assertRaises(AssertionError, **message_kwarg):
            self.assertTensorEqual(*args, **kwargs)

    @staticmethod
    def assertTensorEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_equal``."""
        assert_array_equal(*args, **kwargs)

    @staticmethod
    def assertTensorAlmostEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_almost_equal``."""
        assert_array_almost_equal(*args, **kwargs)

    def get_large_composed_transform(self):
        """Return a ``tio.Compose`` of every ``Random*`` transform, shuffled."""
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [t() for t in all_classes]
        # Hack as default patch size for RandomSwap is 15 and sample_subject
        # is (10, 20, 30)
        for tr in transforms:
            if tr.name == 'RandomSwap':
                tr.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
225
226
227
def get_all_random_transforms():
    """Collect every ``tio.transforms`` attribute whose name starts with ``Random``.

    Returns:
        A list of the matching attributes, in the order produced by
        ``dir(tio.transforms)``.
    """
    return [
        getattr(tio.transforms, attribute_name)
        for attribute_name in dir(tio.transforms)
        if attribute_name.startswith('Random')
    ]
235