tests.utils   A
last analyzed

Complexity

Total Complexity 39

Size/Duplication

Total Lines 259
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 198
dl 0
loc 259
rs 9.28
c 0
b 0
f 0
wmc 39

1 Function

Rating   Name   Duplication   Size   Complexity  
A get_all_random_transforms() 0 6 1

20 Methods

Rating   Name   Duplication   Size   Complexity  
A TorchioTestCase.tearDown() 0 3 1
D TorchioTestCase.get_image_path() 0 46 13
A TorchioTestCase.get_unique_labels() 0 4 1
A TorchioTestCase.assert_tensor_equal() 0 8 1
A TorchioTestCase.flip_affine_x() 0 5 2
A TorchioTestCase.make_multichannel() 0 5 2
A TorchioTestCase.get_inconsistent_shape_subject() 0 23 1
A TorchioTestCase.get_ixi_tiny() 0 3 1
A TorchioTestCase.make_2d() 0 5 2
A TorchioTestCase.get_tensor_with_labels() 0 4 1
A TorchioTestCase.get_tests_data_dir() 0 2 1
A TorchioTestCase.get_reference_image_and_path() 0 9 1
A TorchioTestCase.get_subject_with_labels() 0 6 1
A TorchioTestCase.assert_tensor_almost_equal() 0 6 1
A TorchioTestCase.get_subject_with_partial_volume_label_map() 0 11 1
A TorchioTestCase.setUp() 0 47 1
A TorchioTestCase.flip_coin() 0 2 1
A TorchioTestCase.get_large_composed_transform() 0 10 3
A TorchioTestCase.assert_tensor_all_zeros() 0 3 1
A TorchioTestCase.assert_tensor_not_equal() 0 3 2
1
import copy
2
import os
3
import random
4
import shutil
5
import tempfile
6
import unittest
7
from collections.abc import Sequence
8
from pathlib import Path
9
from random import shuffle
10
11
import numpy as np
12
import pytest
13
import torch
14
15
import torchio as tio
16
17
18
class TorchioTestCase(unittest.TestCase):
    """Base test case providing synthetic images, subjects and tensor checks.

    Each test gets a fresh scratch directory (``self.dir``) where the image
    helpers write their files, and a small ``SubjectsDataset`` built from
    randomly generated images (``self.dataset``).
    """

    def setUp(self):
        """Set up test fixtures, if any."""
        # Unique scratch directory for the files written by get_image_path
        self.dir = Path(tempfile.gettempdir()) / os.urandom(24).hex()
        self.dir.mkdir(exist_ok=True)
        # Seed the Python and NumPy global generators used by the helpers
        # (random.choice, shuffle, np.random) so synthetic data is repeatable
        random.seed(42)
        np.random.seed(42)

        registration_matrix = np.array(
            [
                [1, 0, 0, 10],
                [0, 1, 0, 0],
                [0, 0, 1.2, 0],
                [0, 0, 0, 1],
            ]
        )

        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=4),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d
        self.subject_4d = self.dataset[1]  # subject_a4

    def make_2d(self, subject):
        """Return a deep copy of ``subject`` with each image cut to one slice.

        Every image keeps only the first index along its last axis.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(image.data[..., :1])
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of ``subject`` with each image's data repeated.

        The data is concatenated with itself four times along dim 0.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(torch.cat(4 * (image.data,)))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of ``subject`` with the first world axis negated.

        Each image affine is left-multiplied by ``diag(-1, 1, 1, 1)``.
        """
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31)),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path.

        The image has shape ``(10, 20, 31)`` and spacing ``(1, 1, 2)``.
        """
        path = self.get_image_path(
            'ref',
            shape=(10, 20, 31),
            spacing=(1, 1, 2),
        )
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject with a partial-volume label map.

        The label map is written with ``binary=False`` and no ``labels``, so
        its values come from the non-binary path of ``get_image_path``.
        """
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2',
                    binary=False,
                    components=components,
                ),
            ),
        )

    def get_subject_with_labels(self, labels):
        """Return a subject whose label map contains ``labels`` (and 0)."""
        return tio.Subject(
            label=tio.LabelMap(
                self.get_image_path(
                    'label_multi',
                    labels=labels,
                ),
            ),
        )

    @staticmethod
    def get_unique_labels(data: torch.Tensor) -> set[int]:
        """Return the set of distinct values in ``data``."""
        labels = data.unique().tolist()
        return set(labels)

    @staticmethod
    def get_tensor_with_labels(labels: Sequence) -> torch.Tensor:
        """Return a ``(1, 1, 1, 2 * len(labels))`` tensor, each label twice."""
        tensor = torch.as_tensor(list(labels))
        return tensor.repeat_interleave(2).reshape(1, 1, 1, -1)

    def tearDown(self):
        """Tear down test fixtures, if any."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return the IXITiny dataset rooted at ``<tmp>/torchio/ixi_tiny``.

        The dataset is created with ``download=True``.
        """
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def get_image_path(
        self,
        stem,
        binary=False,
        labels=None,
        shape=(10, 20, 30),
        spacing=(1, 1, 1),
        components=1,
        add_nans=False,
        suffix=None,
        force_binary_foreground=True,
    ):
        """Write a random image to ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: Write a 0/1 ``uint8`` image.
            labels: Label values the voxels are drawn from (plus background
                0). Ignored when ``binary`` is true.
            shape: Spatial shape. A 2-element shape gets a trailing
                singleton dimension.
            spacing: Voxel spacing, placed on the affine diagonal.
            components: Number of channels.
            add_nans: Fill the whole image with NaNs (float64).
            suffix: File extension. If ``None``, one is chosen at random.
            force_binary_foreground: If ``binary`` and the random mask is
                empty, set index 0 of the last axis to 1.

        Returns:
            The path of the written file, randomly either a ``Path`` or a
            ``str``.
        """
        data = self._get_image_data(
            shape,
            components=components,
            binary=binary,
            labels=labels,
            add_nans=add_nans,
            force_binary_foreground=force_binary_foreground,
        )
        affine = np.diag((*spacing, 1))
        if suffix is None:
            extensions = '.nii.gz', '.nii', '.nrrd', '.img', '.mnc'
            suffix = random.choice(extensions)
        path = self.dir / f'{stem}{suffix}'
        # Randomly hand back a str instead of a Path (presumably so callers
        # are exercised with both path types)
        if self.flip_coin():
            path = str(path)
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def _get_image_data(
        self,
        shape,
        components=1,
        binary=False,
        labels=None,
        add_nans=False,
        force_binary_foreground=True,
    ):
        """Return the random ``(components, *shape)`` array for get_image_path.

        Draws from NumPy's global RNG (directly and through ``flip_coin``).
        See ``get_image_path`` for the meaning of the arguments.
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        elif labels is not None:
            # Bin the uniform values into len(labels) + 1 integer classes:
            # class 0 stays background, class i + 1 becomes labels[i].
            # NOTE: the result is uint8, so labels must fit in 0-255.
            bins = (data * (len(labels) + 1)).astype(np.uint8)
            data = np.zeros_like(bins)
            for i, label in enumerate(labels):
                data[bins == (i + 1)] = label
                if not (data == label).sum():
                    # No voxel got this label: paint index i of the last axis
                    data[..., i] = label
        elif self.flip_coin():  # cast some images
            data *= 100
            dtype = np.uint8 if self.flip_coin() else np.uint16
            data = data.astype(dtype)
        if add_nans:
            # The branches above may have produced an integer array, which
            # cannot hold NaN; build a fresh float array instead
            data = np.full(data.shape, np.nan)
        return data

    def flip_coin(self):
        """Return True with probability 1/2, using NumPy's global RNG."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assert_tensor_not_equal(self, *args, **kwargs):  # noqa: N802
        """Assert that ``assert_tensor_equal`` fails for the same arguments."""
        with pytest.raises(AssertionError):
            self.assert_tensor_equal(*args, **kwargs)

    @staticmethod
    def assert_tensor_equal(*args, **kwargs):  # noqa: N802
        """Assert exact equality via ``torch.testing.assert_close``.

        Tolerances are zero. Dtypes are not compared unless the caller
        passes ``check_dtype=True``.
        """
        kwargs.setdefault('check_dtype', False)
        torch.testing.assert_close(*args, rtol=0, atol=0, **kwargs)

    @staticmethod
    def assert_tensor_almost_equal(*args, **kwargs):  # noqa: N802
        """Assert closeness via ``torch.testing.assert_close``.

        Default tolerances apply. Dtypes are not compared unless the caller
        passes ``check_dtype=True``.
        """
        kwargs.setdefault('check_dtype', False)
        torch.testing.assert_close(*args, **kwargs)

    @staticmethod
    def assert_tensor_all_zeros(tensor):  # noqa: N802
        """Assert that every element of ``tensor`` is zero."""
        assert torch.all(tensor == 0)

    def get_large_composed_transform(self):
        """Return a ``Compose`` of every random transform in shuffled order.

        Each transform is instantiated with its default arguments.
        """
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [transform_class() for transform_class in all_classes]
        # Hack: the default patch size of RandomSwap is 15, while
        # sample_subject is (10, 20, 30)
        for transform in transforms:
            if transform.name == 'RandomSwap':
                transform.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
251
252
253
def get_all_random_transforms():
    """Return the attributes of ``tio.transforms`` named ``Random*``.

    Attributes are returned in ``dir()`` order.
    """
    module = tio.transforms
    return [
        getattr(module, attr_name)
        for attr_name in dir(module)
        if attr_name.startswith('Random')
    ]
259