tests.utils.TorchioTestCase.get_image_path()   D
last analyzed

Complexity

Conditions 13

Size

Total Lines 46
Code Lines 44

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 13
eloc 44
nop 10
dl 0
loc 46
rs 4.2
c 0
b 0
f 0

How to fix   Complexity    Many Parameters   

Complexity

Complex code units like tests.utils.TorchioTestCase.get_image_path() (here, a method of the TorchioTestCase class) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

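Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster. For a single complex method, the equivalent is Extract Method: move each cohesive step (for example, generating the image data versus choosing the output path) into its own helper.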

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

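There are several approaches to avoid long parameter lists:

- Replace a parameter with a method call, when the callee can obtain the value itself.
- Introduce a parameter object that groups parameters which always travel together (e.g. shape, spacing and number of components).
- Preserve the whole object: pass the object the values come from instead of several of its individual fields.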

1
import copy
2
import os
3
import random
4
import shutil
5
import tempfile
6
import unittest
7
from collections.abc import Sequence
8
from pathlib import Path
9
from random import shuffle
10
11
import numpy as np
12
import pytest
13
import torch
14
15
import torchio as tio
16
17
18
class TorchioTestCase(unittest.TestCase):
    """Base class for TorchIO tests providing synthetic subjects and helpers.

    ``setUp`` creates a fresh directory under the system temporary directory,
    writes random images into it and builds a small ``tio.SubjectsDataset``;
    ``tearDown`` deletes that directory again.
    """

    def setUp(self):
        """Set up test fixtures, if any."""
        self.dir = Path(tempfile.gettempdir()) / os.urandom(24).hex()
        self.dir.mkdir(exist_ok=True)
        # Seed both generators used by get_image_path: ``random`` picks file
        # extensions, NumPy generates the data and the coin flips.
        random.seed(42)
        np.random.seed(42)

        # Passed as ``pre_affine`` to subject_d's t1 image below.
        registration_matrix = np.array(
            [
                [1, 0, 0, 10],
                [0, 1, 0, 0],
                [0, 0, 1.2, 0],
                [0, 0, 0, 1],
            ]
        )

        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        # NOTE(review): ``components=4`` is passed to ScalarImage, not to
        # get_image_path, and the 't1_a' stem is reused (subject_a's file is
        # overwritten if the same suffix is drawn) — confirm this is intended.
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=4),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d
        self.subject_4d = self.dataset[1]  # subject_a4

    def make_2d(self, subject):
        """Return a deep copy of ``subject`` keeping only the first last-axis slice."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(image.data[..., :1])
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of ``subject`` with each image's data concatenated 4 times."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(torch.cat(4 * (image.data,)))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of ``subject`` with every affine left-multiplied by diag(-1, 1, 1, 1)."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31)),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path."""
        path = self.get_image_path(
            'ref',
            shape=(10, 20, 31),
            spacing=(1, 1, 2),
        )
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject with a partial-volume label map."""
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2',
                    binary=False,
                    components=components,
                ),
            ),
        )

    def get_subject_with_labels(self, labels):
        """Return a subject whose only image is a label map with ``labels``."""
        return tio.Subject(
            label=tio.LabelMap(
                self.get_image_path(
                    'label_multi',
                    labels=labels,
                ),
            ),
        )

    @staticmethod
    def get_unique_labels(data: torch.Tensor) -> set[int]:
        """Return the set of distinct values in ``data``."""
        labels = data.unique().tolist()
        return set(labels)

    @staticmethod
    def get_tensor_with_labels(labels: Sequence) -> torch.Tensor:
        """Return a (1, 1, 1, 2 * len(labels)) tensor with each label repeated twice."""
        tensor = torch.as_tensor(list(labels))
        return tensor.repeat_interleave(2).reshape(1, 1, 1, -1)

    def tearDown(self):
        """Tear down test fixtures, if any."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return the IXITiny dataset rooted in the system temp dir (download=True)."""
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def get_image_path(
        self,
        stem,
        binary=False,
        labels=None,
        shape=(10, 20, 30),
        spacing=(1, 1, 1),
        components=1,
        add_nans=False,
        suffix=None,
        force_binary_foreground=True,
    ):
        """Write a random image into ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: Threshold the data at 0.5 into a ``uint8`` mask.
            labels: If given (and ``binary`` is False), each voxel takes the
                value 0 or one of these labels.
            shape: Spatial shape; a 2-tuple is promoted to ``(*shape, 1)``.
            spacing: Diagonal of the affine matrix.
            components: Number of channels (first axis of the data).
            add_nans: Fill the image with NaNs and disable NaN checking.
            suffix: File extension; drawn at random when ``None``.
            force_binary_foreground: Ensure a binary mask is not all zeros.

        Returns:
            The written path, randomly either a ``Path`` or a ``str``.
        """
        # The helpers consume random numbers in the same order as the
        # original monolithic implementation, keeping fixtures reproducible.
        data = self._make_image_data(
            shape=shape,
            components=components,
            binary=binary,
            labels=labels,
            add_nans=add_nans,
            force_binary_foreground=force_binary_foreground,
        )
        affine = np.diag((*spacing, 1))
        path = self._make_image_path(stem, suffix)
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def _make_image_data(
        self,
        shape=(10, 20, 30),
        components=1,
        binary=False,
        labels=None,
        add_nans=False,
        force_binary_foreground=True,
    ):
        """Return a random array of shape ``(components, *shape)``."""
        if len(shape) == 2:
            shape = (*shape, 1)
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        elif labels is not None:
            data = self._make_label_data(data, labels)
        elif self.flip_coin():  # cast some images to integer types
            data *= 100
            dtype = np.uint8 if self.flip_coin() else np.uint16
            data = data.astype(dtype)
        if add_nans:
            # The data may have an integer dtype at this point, which cannot
            # hold NaN, so build a float array instead of assigning in place.
            data = np.full(data.shape, np.nan)
        return data

    @staticmethod
    def _make_label_data(data, labels):
        """Map uniform noise in [0, 1) to ``uint8`` values 0 or one of ``labels``.

        The loop tries to make every label appear at least once. Label values
        are stored in ``uint8``, so they should fit in [0, 255].
        """
        indices = (data * (len(labels) + 1)).astype(np.uint8)
        label_data = np.zeros_like(indices)
        for i, label in enumerate(labels):
            label_data[indices == (i + 1)] = label
            if not (label_data == label).sum():
                # Label absent by chance: paint slice ``i`` of the last axis.
                # NOTE(review): raises IndexError if i >= shape[-1].
                label_data[..., i] = label
        return label_data

    def _make_image_path(self, stem, suffix=None):
        """Return ``self.dir / (stem + suffix)``, randomly as ``str`` or ``Path``."""
        if suffix is None:
            extensions = '.nii.gz', '.nii', '.nrrd', '.img', '.mnc'
            suffix = random.choice(extensions)
        path = self.dir / f'{stem}{suffix}'
        if self.flip_coin():  # exercise both str and Path inputs
            path = str(path)
        return path

    def flip_coin(self):
        """Return True with probability 0.5, using NumPy's global generator."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assert_tensor_not_equal(self, *args, **kwargs):  # noqa: N802
        """Assert that :meth:`assert_tensor_equal` fails for these arguments."""
        with pytest.raises(AssertionError):
            self.assert_tensor_equal(*args, **kwargs)

    @staticmethod
    def assert_tensor_equal(*args, **kwargs):  # noqa: N802
        """Assert exact equality via ``torch.testing.assert_close``.

        ``rtol``/``atol`` default to 0 and ``check_dtype`` to False; any of
        them may be overridden through ``kwargs``.
        """
        options = {'rtol': 0, 'atol': 0, 'check_dtype': False}
        options.update(kwargs)
        torch.testing.assert_close(*args, **options)

    @staticmethod
    def assert_tensor_almost_equal(*args, **kwargs):  # noqa: N802
        """Assert closeness via ``torch.testing.assert_close``.

        ``check_dtype`` defaults to False and may be overridden via ``kwargs``.
        """
        options = {'check_dtype': False}
        options.update(kwargs)
        torch.testing.assert_close(*args, **options)

    @staticmethod
    def assert_tensor_all_zeros(tensor):  # noqa: N802
        """Assert that every element of ``tensor`` equals zero."""
        assert torch.all(tensor == 0)

    def get_large_composed_transform(self):
        """Return a Compose of every Random* transform in shuffled order."""
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [t() for t in all_classes]
        # Hack as default patch size for RandomSwap is 15 and sample_subject
        # is (10, 20, 30)
        for tr in transforms:
            if tr.name == 'RandomSwap':
                tr.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
251
252
253
def get_all_random_transforms():
    """Collect every attribute of ``tio.transforms`` named ``Random*``.

    The result follows ``dir()`` order, i.e. it is sorted by name.
    """
    module = tio.transforms
    return [
        getattr(module, attribute)
        for attribute in dir(module)
        if attribute.startswith('Random')
    ]
259