Passed
Pull Request — master (#332)
by Fernando
04:34
created

tests.utils.TorchioTestCase.get_tests_data_dir()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import copy
2
import shutil
3
import random
4
import tempfile
5
import unittest
6
from pathlib import Path
7
8
import torch
9
import numpy as np
10
import nibabel as nib
11
from numpy.testing import assert_array_equal, assert_array_almost_equal
12
from torchio.datasets import IXITiny
13
from torchio import DATA, AFFINE
14
from torchio import ScalarImage, LabelMap, SubjectsDataset, Subject
15
16
17
class TorchioTestCase(unittest.TestCase):
    """Base test case with random image fixtures and array assertions.

    Each test gets its own temporary directory (``self.dir``) into which
    randomly generated images are written; ``tearDown`` removes it.
    """

    def setUp(self):
        """Set up test fixtures.

        Creates ``self.dir`` and seeds ``random`` and NumPy's global
        generator with 42. It then builds ``self.subjects_list``,
        ``self.dataset`` and ``self.sample`` (the last subject,
        ``subject_d``) from randomly generated images.
        """
        # A fresh, uniquely named directory per test. A fixed shared name
        # lets tests running concurrently delete each other's files in
        # tearDown, and lets files from interrupted runs leak into new ones.
        self.dir = Path(tempfile.mkdtemp(prefix='torchio_tests_'))
        random.seed(42)
        np.random.seed(42)

        # Translation of 10 along the first axis and scaling by 1.2 along
        # the third; used below as the pre-affine of subject_d's T1 image.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        # NOTE: the order of the get_image_path calls below determines the
        # random data each image receives; keep it stable.
        subject_a = Subject(
            t1=ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = Subject(
            t1=ScalarImage(self.get_image_path('t1_b')),
            label=LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = Subject(
            label=LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = Subject(
            t1=ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=ScalarImage(self.get_image_path('t2_d')),
            label=LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = Subject(
            t1=ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = SubjectsDataset(self.subjects_list)
        self.sample = self.dataset[-1]  # subject_d

    def make_2d(self, sample):
        """Return a deep copy of ``sample`` with 2D images.

        Every image keeps only the first slice along its last axis.
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            image[DATA] = image[DATA][..., :1]
        return sample

    def make_multichannel(self, sample):
        """Return a deep copy of ``sample`` with 4x-repeated image data.

        Each image's data is concatenated with itself 4 times along the
        first dimension (the components dimension for images created by
        ``get_image_path``).
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            image[DATA] = torch.cat(4 * (image[DATA],))
        return sample

    def flip_affine_x(self, sample):
        """Return a deep copy of ``sample`` with the x axis flipped.

        Each affine is left-multiplied by ``diag(-1, 1, 1, 1)``, which
        negates the first (world x) row.
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            image[AFFINE] = np.diag((-1, 1, 1, 1)) @ image[AFFINE]
        return sample

    def get_inconsistent_sample(self):
        """Return a sample containing images of different shape.

        Shapes: t1 (10, 20, 30), t2 (10, 20, 31), label (8, 17, 25),
        label2 (18, 17, 25).
        """
        subject = Subject(
            t1=ScalarImage(self.get_image_path('t1_inc')),
            t2=ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path.

        The image has shape (10, 20, 31) and spacing (1, 1, 2).
        """
        path = self.get_image_path('ref', shape=(10, 20, 31), spacing=(1, 1, 2))
        image = ScalarImage(path)
        return image, path

    def get_sample_with_partial_volume_label_map(self, components=1):
        """Return a sample with a partial-volume label map.

        The label map is written with ``binary=False``, so it holds the raw
        random values in [0, 1) rather than 0/1 labels.

        Args:
            components: Number of components of the label map.
        """
        return Subject(
            t1=ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def tearDown(self):
        """Tear down test fixtures by deleting ``self.dir`` recursively."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return the IXITiny dataset stored under ``<tmp>/torchio/ixi_tiny``.

        ``download=True`` is passed, so the data may be fetched over the
        network on first use.
        """
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return IXITiny(root_dir, download=True)

    def get_image_path(
            self,
            stem,
            binary=False,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
            ):
        """Write a random image to ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: If True, threshold the uniform random values at 0.5 and
                store them as ``uint8``.
            shape: Spatial shape. A 2-tuple gets a trailing singleton axis.
            spacing: Diagonal of the affine (voxel size per axis).
            components: Size of the first dimension of the data array.
            add_nans: If True, every voxel is NaN (stored as float, even for
                ``binary`` images) and the NaN check of ``ScalarImage`` is
                disabled.
            suffix: File extension. If None, one of ``.nii.gz``, ``.nii``,
                ``.nrrd``, ``.img`` or ``.mnc`` is chosen at random.
            force_binary_foreground: If True and a binary image came out
                all zeros, set the first slice along the last axis to 1.

        Returns:
            The path of the written file, randomly either a ``str`` or a
            ``pathlib.Path`` (presumably to exercise both argument types).
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        if add_nans:
            # NaN cannot be assigned into an integer array (doing so on the
            # uint8 data of a binary image raises ValueError), so replace
            # the data with a float array filled with NaN instead.
            data = np.full(data.shape, np.nan)
        affine = np.diag((*spacing, 1))
        if suffix is None:
            suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img', '.mnc'))
        path = self.dir / f'{stem}{suffix}'
        if np.random.rand() > 0.5:
            path = str(path)
        image = ScalarImage(tensor=data, affine=affine, check_nans=not add_nans)
        image.save(path)
        return path

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this module."""
        return Path(__file__).parent / 'image_data'

    def assertTensorNotEqual(self, *args, msg=None, **kwargs):  # noqa: N802
        """Assert that two arrays/tensors are NOT element-wise equal.

        Takes the same arguments as ``assertTensorEqual``. The failure
        message comes from ``msg=`` or, as before, from a third positional
        argument (which is also forwarded as NumPy's ``err_msg``).
        """
        if msg is None and len(args) >= 3:
            msg = args[2]
        message_kwarg = {} if msg is None else {'msg': msg}
        with self.assertRaises(AssertionError, **message_kwarg):
            self.assertTensorEqual(*args, msg=msg, **kwargs)

    def assertTensorEqual(self, *args, msg=None, **kwargs):  # noqa: N802
        """Assert element-wise equality via ``numpy.testing.assert_array_equal``.

        ``msg`` (unittest convention) is mapped onto NumPy's ``err_msg``,
        unless ``err_msg`` was already given positionally or by keyword.
        """
        # assert_array_equal(x, y, err_msg, ...): err_msg is positional #3.
        if msg is not None and len(args) < 3:
            kwargs.setdefault('err_msg', msg)
        assert_array_equal(*args, **kwargs)

    def assertTensorAlmostEqual(self, *args, msg=None, **kwargs):  # noqa: N802
        """Assert near-equality via ``numpy.testing.assert_array_almost_equal``.

        ``msg`` (unittest convention) is mapped onto NumPy's ``err_msg``,
        unless ``err_msg`` was already given positionally or by keyword.
        """
        # assert_array_almost_equal(x, y, decimal, err_msg, ...): err_msg
        # is positional #4.
        if msg is not None and len(args) < 4:
            kwargs.setdefault('err_msg', msg)
        assert_array_almost_equal(*args, **kwargs)
174