Passed
Pull Request — master (#380)
by Fernando
01:25
created

tests.utils.TorchioTestCase.flip_coin()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import copy
2
import shutil
3
import random
4
import tempfile
5
import unittest
6
from pathlib import Path
7
from random import shuffle
8
9
import torch
10
import numpy as np
11
from numpy.testing import assert_array_equal, assert_array_almost_equal
12
import torchio as tio
13
14
15
class TorchioTestCase(unittest.TestCase):
    """Base test case providing synthetic images, subjects and assertions.

    ``setUp`` creates a scratch directory (``self.dir``) under the system temp
    directory, seeds Python's and NumPy's global RNGs and writes a few random
    images that are grouped into subjects and a ``tio.SubjectsDataset``.
    ``tearDown`` deletes the scratch directory again.

    NOTE: the scratch directory name is fixed (``.torchio_tests``), so
    concurrent test processes would share it. The torch RNG is not seeded
    here.
    """

    def setUp(self):
        """Create the scratch dir, seed the RNGs and build sample subjects."""
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        random.seed(42)
        np.random.seed(42)

        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        # Each get_image_path call consumes RNG draws (data, dtype, suffix,
        # path type), so the creation order below determines the fixtures.
        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d

    def make_2d(self, subject):
        """Return a deep copy of ``subject`` whose images keep only the first
        slice along the last axis."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.data = image.data[..., :1]
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of ``subject`` whose image data is repeated four
        times along the first (channel) dimension."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.data = torch.cat(4 * (image.data,))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of ``subject`` with the first row of every
        image's affine negated (a flip of the x world axis)."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image of shape (10, 20, 31) with spacing
        (1, 1, 2), together with the path it was saved to."""
        path = self.get_image_path('ref', shape=(10, 20, 31), spacing=(1, 1, 2))
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject whose label map is written with ``binary=False``.

        Args:
            components: Number of channels of the label map.
        """
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def tearDown(self):
        """Delete the scratch directory created in ``setUp``."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return the ``IXITiny`` dataset rooted under the system temp dir,
        constructed with ``download=True``."""
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def get_image_path(
            self,
            stem,
            binary=False,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
    ):
        """Write a random image into ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: If ``True``, threshold uniform noise at 0.5 into a
                ``uint8`` array of zeros and ones.
            shape: Spatial shape. A 2-tuple gets a trailing singleton axis.
            spacing: Voxel spacing placed on the diagonal of the affine.
            components: Size of the first (channel) axis.
            add_nans: If ``True``, fill the whole image with NaN and disable
                the image's NaN check. The data is always floating point in
                this case.
            suffix: File extension. It is chosen at random when ``None``.
            force_binary_foreground: For binary images, set the first slice
                along the last axis to 1 if thresholding produced no
                foreground.

        Returns:
            The saved file's location, as a ``Path`` or (at random) a ``str``.
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        elif self.flip_coin():  # cast some images
            data *= 100
            dtype = np.uint8 if self.flip_coin() else np.uint16
            data = data.astype(dtype)
        if add_nans:
            # The branches above may have produced an integer array, and
            # assigning NaN into one raises ValueError. Build a float array
            # instead. This consumes no RNG draws, so later images are
            # unaffected.
            data = np.full(data.shape, np.nan)
        affine = np.diag((*spacing, 1))
        if suffix is None:
            suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img', '.mnc'))
        path = self.dir / f'{stem}{suffix}'
        if self.flip_coin():
            path = str(path)
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def flip_coin(self):
        """Return ``True`` with probability 1/2, using NumPy's global RNG."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assertTensorNotEqual(self, *args, **kwargs):  # noqa: N802
        """Assert that ``assertTensorEqual(*args, **kwargs)`` fails.

        The message, given either as the third positional argument or as
        ``err_msg``, is also reported when the arrays turn out to be equal.
        """
        if len(args) >= 3:
            message_kwarg = {'msg': args[2]}
        elif 'err_msg' in kwargs:
            message_kwarg = {'msg': kwargs['err_msg']}
        else:
            message_kwarg = {}
        with self.assertRaises(AssertionError, **message_kwarg):
            self.assertTensorEqual(*args, **kwargs)

    @staticmethod
    def assertTensorEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_equal``."""
        assert_array_equal(*args, **kwargs)

    @staticmethod
    def assertTensorAlmostEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_almost_equal``."""
        assert_array_almost_equal(*args, **kwargs)

    def get_large_composed_transform(self):
        """Return a ``tio.Compose`` of every random transform, in shuffled
        order, each instantiated with default arguments."""
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [t() for t in all_classes]
        # Hack as default patch size for RandomSwap is 15 and sample_subject
        # is (10, 20, 30)
        for tr in transforms:
            if tr.name == 'RandomSwap':
                tr.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
196
197
198
def get_all_random_transforms():
    """Collect the attributes of ``tio.transforms`` named ``Random*``.

    The names come from ``dir``, so the result is in alphabetical order of
    attribute name. The list is freshly built on each call, so callers may
    shuffle it in place.
    """
    module = tio.transforms
    return [
        getattr(module, attribute)
        for attribute in dir(module)
        if attribute.startswith('Random')
    ]
206