TorchioTestCase.get_inconsistent_shape_subject()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 1     |
| eloc   | 16    |
| nop    | 1     |
| dl     | 0     |
| loc    | 22    |
| rs     | 9.6   |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import os
2
import copy
3
import shutil
4
import random
5
import tempfile
6
import unittest
7
from pathlib import Path
8
from random import shuffle
9
10
import torch
11
import numpy as np
12
from numpy.testing import assert_array_equal, assert_array_almost_equal
13
import torchio as tio
14
15
16
class TorchioTestCase(unittest.TestCase):
    """Base test case with helpers to build synthetic images and subjects.

    Each test gets a fresh temporary directory (``self.dir``) where the
    synthetic images are written. The Python and NumPy global random
    generators are seeded in ``setUp`` so the generated data is reproducible.
    """

    def setUp(self):
        """Set up test fixtures, if any."""
        self.dir = Path(tempfile.gettempdir()) / os.urandom(24).hex()
        self.dir.mkdir(exist_ok=True)
        random.seed(42)
        np.random.seed(42)

        # Affine with a translation of 10 along the first axis and a scaling
        # of 1.2 along the third one; used as ``pre_affine`` for subject_d.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d

    def make_2d(self, subject):
        """Return a deep copy of ``subject`` with every image reduced to the
        first slice along its last axis."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(image.data[..., :1])
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of ``subject`` whose image data is repeated four
        times along the first dimension (``torch.cat`` default ``dim=0``)."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(torch.cat(4 * (image.data,)))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of ``subject`` with the first row of every
        image affine negated (left-multiplied by ``diag(-1, 1, 1, 1)``)."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path."""
        path = self.get_image_path(
            'ref',
            shape=(10, 20, 31),
            spacing=(1, 1, 2),
        )
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject with a partial-volume label map."""
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def get_subject_with_labels(self, labels):
        """Return a subject with a single label map built from ``labels``
        (voxels not assigned any label are zero)."""
        return tio.Subject(
            label=tio.LabelMap(
                self.get_image_path(
                    'label_multi', labels=labels
                )
            )
        )

    def get_unique_labels(self, label_map):
        """Return the set of non-zero values in ``label_map.data`` as Python
        scalars."""
        labels = torch.unique(label_map.data)
        labels = {i.item() for i in labels if i != 0}
        return labels

    def tearDown(self):
        """Tear down test fixtures, if any."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return ``tio.datasets.IXITiny`` rooted at
        ``<tempdir>/torchio/ixi_tiny``, created with ``download=True``."""
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def _make_random_image_data(
            self,
            shape,
            binary=False,
            labels=None,
            components=1,
            add_nans=False,
            force_binary_foreground=True,
            ):
        """Return a random NumPy array of shape ``(components, *shape)``.

        A 2D ``shape`` gets a trailing singleton dimension. See
        :meth:`get_image_path` for the meaning of the other arguments. Uses
        NumPy's global random generator only.
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        elif labels is not None:
            # Quantize to 0..len(labels), then map value i + 1 to labels[i].
            data = (data * (len(labels) + 1)).astype(np.uint8)
            new_data = np.zeros_like(data)
            for i, label in enumerate(labels):
                new_data[data == (i + 1)] = label
                if not (new_data == label).sum():
                    # Label missing: force it into slice i of the last axis.
                    new_data[..., i] = label
            data = new_data
        elif self.flip_coin():  # cast some images
            data *= 100
            dtype = np.uint8 if self.flip_coin() else np.uint16
            data = data.astype(dtype)
        if add_nans:
            # The branches above may have produced an integer array, which
            # cannot hold NaN (in-place assignment raises ValueError), so
            # build a float array of the same shape instead.
            data = np.full(data.shape, np.nan)
        return data

    def get_image_path(
            self,
            stem,
            binary=False,
            labels=None,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
            ):
        """Write a random image to ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: If True, voxel values are 0 or 1 (uint8).
            labels: Optional sequence of label values to place in the image;
                ignored when ``binary`` is True.
            shape: 2D or 3D spatial shape; a 2D shape gets a trailing
                singleton dimension.
            spacing: Voxel spacing, written on the diagonal of the affine.
            components: Number of channels (first array dimension).
            add_nans: If True, every voxel is NaN and the image is created
                with ``check_nans=False``.
            suffix: File extension. If None, one is picked at random.
            force_binary_foreground: If True and a binary image came out all
                zeros, its first slice along the last axis is set to one.

        Returns:
            The path of the saved image, randomly as a ``Path`` or a ``str``.
        """
        data = self._make_random_image_data(
            shape,
            binary=binary,
            labels=labels,
            components=components,
            add_nans=add_nans,
            force_binary_foreground=force_binary_foreground,
        )
        affine = np.diag((*spacing, 1))
        if suffix is None:
            extensions = '.nii.gz', '.nii', '.nrrd', '.img', '.mnc'
            suffix = random.choice(extensions)
        path = self.dir / f'{stem}{suffix}'
        if self.flip_coin():
            path = str(path)  # exercise both str and Path inputs
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def flip_coin(self):
        """Return True roughly half of the time (NumPy global RNG)."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assertTensorNotEqual(self, *args, **kwargs):  # noqa: N802
        """Assert that :meth:`assertTensorEqual` fails for these arguments.

        A third positional argument is also used as the failure message.
        """
        message_kwarg = {'msg': args[2]} if len(args) == 3 else {}
        with self.assertRaises(AssertionError, **message_kwarg):
            self.assertTensorEqual(*args, **kwargs)

    @staticmethod
    def assertTensorEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_equal``."""
        assert_array_equal(*args, **kwargs)

    @staticmethod
    def assertTensorAlmostEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_almost_equal``."""
        assert_array_almost_equal(*args, **kwargs)

    def get_large_composed_transform(self):
        """Return a ``tio.Compose`` of every ``Random*`` transform, with
        default arguments, in shuffled order."""
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [t() for t in all_classes]
        # Hack as default patch size for RandomSwap is 15 and sample_subject
        # is (10, 20, 30)
        for tr in transforms:
            if tr.name == 'RandomSwap':
                tr.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
225
226
227
def get_all_random_transforms():
    """Collect the attributes of ``tio.transforms`` named ``Random*``.

    Names are visited in ``dir()`` order, which is sorted alphabetically.
    """
    module = tio.transforms
    random_classes = []
    for attribute_name in dir(module):
        if not attribute_name.startswith('Random'):
            continue
        random_classes.append(getattr(module, attribute_name))
    return random_classes
235