Passed
Push — master ( b9ac52...6aebda )
by Fernando
10:37 queued 20s
created

tests.utils.TorchioTestCase.get_image_path()   D

Complexity

Conditions 13

Size

Total Lines 45
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 13
eloc 43
nop 10
dl 0
loc 45
rs 4.2
c 0
b 0
f 0

How to fix   Complexity    Many Parameters   

Complexity

Complex methods such as tests.utils.TorchioTestCase.get_image_path() often do a lot of different things. To break such a method (or its enclosing class) down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

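There are several approaches to avoiding long parameter lists, for example introducing a parameter object that groups related parameters, passing a whole object instead of several of its fields, or replacing a parameter with a method call that computes it.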

1
import copy
2
import shutil
3
import random
4
import tempfile
5
import unittest
6
from pathlib import Path
7
from random import shuffle
8
9
import torch
10
import numpy as np
11
from numpy.testing import assert_array_equal, assert_array_almost_equal
12
import torchio as tio
13
14
15
class TorchioTestCase(unittest.TestCase):
    """Base test case that writes random images and subjects to disk.

    All images are saved in a shared temporary directory
    (``<tempdir>/.torchio_tests``) that is created in ``setUp`` and removed in
    ``tearDown``. Python's ``random`` and NumPy's global RNG are both seeded
    in ``setUp`` so that generated data is reproducible for each test.
    """

    # Candidate file extensions for randomly generated images.
    _IMAGE_SUFFIXES = ('.nii.gz', '.nii', '.nrrd', '.img', '.mnc')

    def setUp(self):
        """Create the temp directory, seed the RNGs and build test subjects.

        Sets ``self.dir``, ``self.subjects_list``, ``self.dataset`` and
        ``self.sample_subject`` (the last subject of the dataset).
        """
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        random.seed(42)
        np.random.seed(42)

        # Affine with a translation of 10 on the first axis and a 1.2 scale
        # on the third, used as ``pre_affine`` for subject_d's T1 image.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        # Creation order matters: every get_image_path call consumes draws
        # from the seeded global RNGs.
        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        # NOTE(review): this reuses the 't1_a' stem, so if the random suffix
        # matches subject_a's, subject_a's file is overwritten. Also,
        # ``components=2`` is given to tio.ScalarImage, not get_image_path,
        # so the written image has a single component -- confirm intent.
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d

    def tearDown(self):
        """Delete the temporary directory and every image written to it."""
        shutil.rmtree(self.dir)

    def make_2d(self, subject):
        """Return a deep copy of ``subject`` with only the first last-axis slice.

        Every image (intensity and label) keeps ``data[..., :1]``.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.data = image.data[..., :1]
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of ``subject`` with each image's data repeated 4x.

        The data is concatenated with itself along dim 0 (presumably the
        channel dimension).
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.data = torch.cat(4 * (image.data,))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of ``subject`` with the first world axis negated.

        Each image affine is left-multiplied by ``diag(-1, 1, 1, 1)``.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a (10, 20, 31) reference image with spacing (1, 1, 2) and its path."""
        path = self.get_image_path('ref', shape=(10, 20, 31), spacing=(1, 1, 2))
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject whose label map holds non-binary random values.

        Args:
            components: Number of components of the label image.
        """
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def get_subject_with_labels(self, labels):
        """Return a subject with a single label map containing ``labels``."""
        return tio.Subject(
            label=tio.LabelMap(
                self.get_image_path(
                    'label_multi', labels=labels
                )
            )
        )

    def get_unique_labels(self, label_map):
        """Return the set of non-zero values in ``label_map.data``."""
        labels = torch.unique(label_map.data)
        labels = {i.item() for i in labels if i != 0}
        return labels

    def get_ixi_tiny(self):
        """Return ``tio.datasets.IXITiny`` rooted at ``<tempdir>/torchio/ixi_tiny``.

        The dataset is constructed with ``download=True``.
        """
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def get_image_path(
            self,
            stem,
            binary=False,
            labels=None,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
            ):
        """Write a random image into ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: If True, write a uint8 mask of zeros and ones.
            labels: If given (and ``binary`` is False), write a uint8 label
                map containing these label values.
            shape: Spatial shape; a 2-tuple gets a trailing singleton axis.
            spacing: Voxel spacing placed on the diagonal of the affine.
            components: Size of the leading (first) array dimension.
            add_nans: If True, every voxel is NaN and NaN checks are disabled.
            suffix: File extension; chosen at random when None.
            force_binary_foreground: For binary images, set the first
                last-axis slice to 1 if the random mask came out empty.

        Returns:
            The written file path, randomly either a ``Path`` or a ``str``.
        """
        data = self._make_image_data(
            shape=shape,
            binary=binary,
            labels=labels,
            components=components,
            add_nans=add_nans,
            force_binary_foreground=force_binary_foreground,
        )
        affine = np.diag((*spacing, 1))
        if suffix is None:
            suffix = random.choice(self._IMAGE_SUFFIXES)
        path = self.dir / f'{stem}{suffix}'
        if self.flip_coin():  # exercise both str and Path inputs
            path = str(path)
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def _make_image_data(
            self,
            shape=(10, 20, 30),
            binary=False,
            labels=None,
            components=1,
            add_nans=False,
            force_binary_foreground=True,
            ):
        """Return the random array written by ``get_image_path``.

        The result has shape ``(components, *shape)``, with a trailing
        singleton axis appended when ``shape`` has two elements.
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = self._binarize(data, force_binary_foreground)
        elif labels is not None:
            data = self._to_label_map(data, labels)
        elif self.flip_coin():  # cast some images
            data = self._random_integer_cast(data)
        if add_nans:
            # Build a fresh float array: assigning NaN in place would raise
            # ValueError for the integer arrays produced above.
            data = np.full(data.shape, np.nan)
        return data

    @staticmethod
    def _binarize(data, force_foreground):
        """Threshold ``data`` at 0.5 into a uint8 mask of zeros and ones.

        If the mask is empty and ``force_foreground`` is True, the first
        slice along the last axis is set to 1.
        """
        mask = (data > 0.5).astype(np.uint8)
        if not mask.sum() and force_foreground:
            mask[..., 0] = 1
        return mask

    @staticmethod
    def _to_label_map(data, labels):
        """Turn uniform noise in [0, 1) into a uint8 map containing ``labels``.

        The noise is split into ``len(labels) + 1`` equal bins. Bin 0 stays
        background and bin ``i + 1`` becomes ``labels[i]``.
        """
        # NOTE(review): the output is uint8, so label values above 255 wrap.
        bins = (data * (len(labels) + 1)).astype(np.uint8)
        label_map = np.zeros_like(bins)
        for index, label in enumerate(labels):
            label_map[bins == index + 1] = label
            if not (label_map == label).sum():
                # Ensure the label is present by painting the slice at
                # position ``index`` of the last axis.
                label_map[..., index] = label
        return label_map

    def _random_integer_cast(self, data):
        """Scale ``data`` by 100 and cast it to uint8 or uint16 at random."""
        scaled = data * 100
        dtype = np.uint8 if self.flip_coin() else np.uint16
        return scaled.astype(dtype)

    def flip_coin(self):
        """Return True with probability 1/2, using NumPy's global RNG."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assertTensorNotEqual(self, *args, **kwargs):  # noqa: N802
        """Assert that ``assertTensorEqual`` fails for the given arguments.

        When exactly three positional arguments are given, the third is also
        used as the ``msg`` of ``assertRaises``.
        """
        message_kwarg = {'msg': args[2]} if len(args) == 3 else {}
        with self.assertRaises(AssertionError, **message_kwarg):
            self.assertTensorEqual(*args, **kwargs)

    @staticmethod
    def assertTensorEqual(*args, **kwargs):  # noqa: N802
        """Delegate to ``numpy.testing.assert_array_equal``."""
        assert_array_equal(*args, **kwargs)

    @staticmethod
    def assertTensorAlmostEqual(*args, **kwargs):  # noqa: N802
        """Delegate to ``numpy.testing.assert_array_almost_equal``."""
        assert_array_almost_equal(*args, **kwargs)

    def get_large_composed_transform(self):
        """Return a ``tio.Compose`` of every ``Random*`` transform, shuffled.

        Each transform is instantiated with default arguments.
        """
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [t() for t in all_classes]
        # Hack as default patch size for RandomSwap is 15 and sample_subject
        # is (10, 20, 30)
        for tr in transforms:
            if tr.name == 'RandomSwap':
                tr.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
219
220
221
def get_all_random_transforms():
    """Return every ``tio.transforms`` attribute whose name starts with 'Random'.

    These are presumably the random transform classes. They are returned in
    ``dir()`` order as a new list, so callers may shuffle it in place.
    """
    module = tio.transforms
    return [
        getattr(module, attribute)
        for attribute in dir(module)
        if attribute.startswith('Random')
    ]
229