Passed
Pull Request — master (#380)
by Fernando
01:10
created

tests.utils.TorchioTestCase.get_image_path()   C

Complexity

Conditions 10

Size

Total Lines 36
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 10
eloc 34
nop 9
dl 0
loc 36
rs 5.9999
c 0
b 0
f 0

How to fix   Complexity    Many Parameters   

Complexity

Complex methods like `tests.utils.TorchioTestCase.get_image_path()` often do a lot of different things. To break such a method (or class) down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

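Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster. For a single method, the analogous refactoring is Extract Method: move each cohesive step (for example, generating data versus choosing an output path) into its own helper.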

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

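There are several approaches to avoid long parameter lists: replace a parameter with a method call when the callee can obtain the value itself, introduce a parameter object that groups values which always travel together, or pass a whole object instead of several of its fields.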

1
import copy
2
import shutil
3
import random
4
import tempfile
5
import unittest
6
from pathlib import Path
7
from random import shuffle
8
9
import torch
10
import numpy as np
11
from numpy.testing import assert_array_equal, assert_array_almost_equal
12
import torchio as tio
13
14
15
class TorchioTestCase(unittest.TestCase):
    """Base test case that builds random image files and sample subjects.

    Each test gets a scratch directory (``self.dir``) that is removed in
    :meth:`tearDown`, seeded ``random`` and NumPy generators, and a small
    ``tio.SubjectsDataset`` made of randomly generated images.
    """

    def setUp(self):
        """Set up test fixtures, if any."""
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        # Both generators are used: NumPy for image data and coin flips,
        # the standard ``random`` module for file suffixes and shuffling.
        random.seed(42)
        np.random.seed(42)

        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d

    def make_2d(self, subject):
        """Return a deep copy of *subject* keeping only the first slice.

        Each image's data is cut to ``[..., :1]``, i.e. its last axis is
        reduced to length one.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.data = image.data[..., :1]
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of *subject* with image data repeated 4 times.

        The copies are concatenated with ``torch.cat`` along dimension 0.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.data = torch.cat(4 * (image.data,))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of *subject* with the first affine row negated.

        Every affine is left-multiplied by ``diag(-1, 1, 1, 1)``.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path."""
        path = self.get_image_path('ref', shape=(10, 20, 31), spacing=(1, 1, 2))
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject with a partial-volume label map."""
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def tearDown(self):
        """Tear down test fixtures, if any."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return ``tio.datasets.IXITiny`` rooted in the system temp dir.

        The dataset is constructed with ``download=True``.
        """
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def get_image_path(
            self,
            stem,
            binary=False,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
            ):
        """Save a random image in the scratch directory and return its path.

        Args:
            stem: File name without extension.
            binary: If ``True``, the data is a ``uint8`` array of 0s and 1s.
                Otherwise it is float data in ``[0, 1)``, randomly scaled by
                100 and cast to ``uint8`` or ``uint16`` about half the time.
            shape: Spatial shape. A 2-tuple gets a trailing axis of size 1.
            spacing: Voxel spacing, placed on the diagonal of the affine.
            components: Number of channels (size of the first data axis).
            add_nans: If ``True``, every voxel is NaN (float data) and the
                image is created with ``check_nans=False``.
            suffix: File extension including the dot. If ``None``, one of
                ``.nii.gz``, ``.nii``, ``.nrrd``, ``.img`` or ``.mnc`` is
                chosen at random.
            force_binary_foreground: If ``True`` and a binary image came out
                all zeros, the first index of the last axis is set to 1.

        Returns:
            The path of the saved file, randomly as a ``Path`` or a ``str``.
        """
        data = self._random_image_data(
            shape, components, binary, add_nans, force_binary_foreground,
        )
        affine = np.diag((*spacing, 1))
        path = self._random_image_path(stem, suffix)
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def _random_image_data(
            self,
            shape,
            components,
            binary,
            add_nans,
            force_binary_foreground,
            ):
        """Return random voxel data of shape ``(components, *shape)``.

        See :meth:`get_image_path` for the meaning of each argument.
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        elif self.flip_coin():  # cast some images to integer types
            data *= 100
            dtype = np.uint8 if self.flip_coin() else np.uint16
            data = data.astype(dtype)
        if add_nans:
            # NaN only exists in floating point: assigning it in place into
            # the uint8/uint16 arrays produced above would fail, so build a
            # new float array (the random draws above are kept so the
            # generator state matches the non-NaN path).
            data = np.full(data.shape, np.nan)
        return data

    def _random_image_path(self, stem, suffix=None):
        """Return ``self.dir / f'{stem}{suffix}'``, randomly as Path or str."""
        if suffix is None:
            suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img', '.mnc'))
        path = self.dir / f'{stem}{suffix}'
        if self.flip_coin():
            path = str(path)
        return path

    def flip_coin(self):
        """Return ``np.random.rand() > 0.5`` (NumPy's global generator)."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assertTensorNotEqual(self, *args, **kwargs):  # noqa: N802
        """Assert that ``assertTensorEqual`` fails for the given arguments.

        A third positional argument is used as the message both for
        ``assertRaises`` and for the underlying comparison.
        """
        message_kwarg = dict(msg=args[2]) if len(args) == 3 else {}
        with self.assertRaises(AssertionError, **message_kwarg):
            self.assertTensorEqual(*args, **kwargs)

    @staticmethod
    def assertTensorEqual(*args, **kwargs):  # noqa: N802
        """Delegate to ``numpy.testing.assert_array_equal``."""
        assert_array_equal(*args, **kwargs)

    @staticmethod
    def assertTensorAlmostEqual(*args, **kwargs):  # noqa: N802
        """Delegate to ``numpy.testing.assert_array_almost_equal``."""
        assert_array_almost_equal(*args, **kwargs)

    def get_large_composed_transform(self):
        """Return a ``tio.Compose`` of every ``Random*`` transform, shuffled.

        Each transform class is instantiated with its default arguments.
        """
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [t() for t in all_classes]
        # Hack as default patch size for RandomSwap is 15 and sample_subject
        # is (10, 20, 30)
        for tr in transforms:
            if tr.name == 'RandomSwap':
                tr.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
196
197
198
def get_all_random_transforms():
    """Return every attribute of ``tio.transforms`` named ``Random*``.

    The attributes come back in ``dir()`` order, which is alphabetical.
    """
    module = tio.transforms
    return [
        getattr(module, attribute)
        for attribute in dir(module)
        if attribute.startswith('Random')
    ]
206