Passed
Push — master ( 10dd9a...d701c3 )
by Fernando
01:33
created

TorchioTestCase.get_inconsistent_sample()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 16
nop 1
dl 0
loc 22
rs 9.6
c 0
b 0
f 0
1
import copy
2
import shutil
3
import random
4
import tempfile
5
import unittest
6
from pathlib import Path
7
8
import torch
9
import numpy as np
10
import nibabel as nib
11
from numpy.testing import (
12
    assert_array_equal,
13
    assert_array_almost_equal,
14
    assert_raises,
15
)
16
from torchio.datasets import IXITiny
17
from torchio import DATA, AFFINE
18
from torchio import ScalarImage, LabelMap, SubjectsDataset, Subject
19
20
21
class TorchioTestCase(unittest.TestCase):
    """Base test case providing randomly generated images and subjects.

    Each test gets a scratch directory (``self.dir``) under the system
    temporary directory, seeded ``random``/``np.random`` generators, a list
    of example subjects (``self.subjects_list``), a ``SubjectsDataset`` built
    from them (``self.dataset``) and its last item (``self.sample``).
    """

    def setUp(self):
        """Set up test fixtures, if any."""
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        # Seed both generators: get_image_path draws from ``random`` (file
        # suffix) and from ``np.random`` (voxel values, str-vs-Path choice),
        # so fixtures are reproducible only if the call order below is kept.
        random.seed(42)
        np.random.seed(42)

        # Passed as ``pre_affine`` to subject_d's T1: translation of 10 along
        # the first axis and a 1.2 scale on the third axis.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = Subject(
            t1=ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = Subject(
            t1=ScalarImage(self.get_image_path('t1_b')),
            label=LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = Subject(
            label=LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = Subject(
            t1=ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=ScalarImage(self.get_image_path('t2_d')),
            label=LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = Subject(
            t1=ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = SubjectsDataset(self.subjects_list)
        self.sample = self.dataset[-1]  # subject_d

    def make_2d(self, sample):
        """Return a deep copy of ``sample`` reduced to a single last-axis slice.

        Every image returned by ``get_images(intensity_only=False)`` keeps only
        index 0 of its last axis (``[..., :1]``); the input is not modified.
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            image[DATA] = image[DATA][..., :1]
        return sample

    def make_multichannel(self, sample):
        """Return a deep copy of ``sample`` with each image's data repeated 4 times.

        The four copies are concatenated with ``torch.cat`` along dimension 0
        (presumably the channel axis — confirm against torchio's layout).
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            image[DATA] = torch.cat(4 * (image[DATA],))
        return sample

    def flip_affine_x(self, sample):
        """Return a deep copy of ``sample`` whose affines are mirrored in x.

        Each affine is left-multiplied by ``diag(-1, 1, 1, 1)``, which negates
        its first row (the world-x mapping); the voxel data is untouched.
        """
        sample = copy.deepcopy(sample)
        for image in sample.get_images(intensity_only=False):
            image[AFFINE] = np.diag((-1, 1, 1, 1)) @ image[AFFINE]
        return sample

    def get_inconsistent_sample(self):
        """Return a sample containing images of different shape.

        Shapes: ``t1`` (10, 20, 30) (the default), ``t2`` (10, 20, 31),
        ``label`` (8, 17, 25) and ``label2`` (18, 17, 25).
        """
        subject = Subject(
            t1=ScalarImage(self.get_image_path('t1_inc')),
            t2=ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path.

        The image has shape (10, 20, 31) and spacing (1, 1, 2).
        """
        path = self.get_image_path('ref', shape=(10, 20, 31), spacing=(1, 1, 2))
        image = ScalarImage(path)
        return image, path

    def get_sample_with_partial_volume_label_map(self, components=1):
        """Return a sample with a partial-volume label map.

        The label map is generated with ``binary=False``, so its voxels hold
        continuous random values in [0, 1) rather than 0/1 labels.

        Args:
            components: Number of components (leading axis) of the label map.
        """
        return Subject(
            t1=ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def tearDown(self):
        """Tear down test fixtures, if any (recursively delete ``self.dir``)."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return an ``IXITiny`` dataset rooted at ``<tmp>/torchio/ixi_tiny``.

        Created with ``download=True`` (presumably fetched on first use and
        reused afterwards — confirm against ``IXITiny``).
        """
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return IXITiny(root_dir, download=True)

    @staticmethod
    def _random_image_data(
            shape,
            components=1,
            binary=False,
            add_nans=False,
            force_binary_foreground=True,
            ):
        """Return a random ``(components, *shape)`` array for a test image.

        A 2D ``shape`` gets a trailing singleton axis. Values come from
        ``np.random.rand``; with ``binary`` they are thresholded at 0.5 into
        ``uint8`` (and, with ``force_binary_foreground``, an all-zero result
        gets ones at index 0 of the last axis). With ``add_nans`` the result
        is a float64 array filled with NaN.
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                # Avoid empty label maps, which many transforms reject.
                data[..., 0] = 1
        if add_nans:
            # NaN is not representable in the uint8 array produced by
            # ``binary``; build a float array instead of assigning in place.
            data = np.full(data.shape, np.nan)
        return data

    def get_image_path(
            self,
            stem,
            binary=False,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
            ):
        """Write a random image to ``self.dir`` and return its path.

        Args:
            stem: File name without suffix.
            binary: Threshold the random data at 0.5 into a uint8 0/1 image.
            shape: Spatial shape; a 2D shape gets a trailing singleton axis.
            spacing: Diagonal of the affine (voxel size).
            components: Size of the leading (components) axis.
            add_nans: Fill the image with NaN and disable the NaN check.
            suffix: File extension; chosen at random among several formats
                when ``None``.
            force_binary_foreground: With ``binary``, guarantee at least one
                nonzero voxel.

        Returns:
            The saved file's path, randomly as ``str`` or ``Path`` so callers
            exercise both input types.
        """
        # RNG order matters for seeded fixtures: data first, then suffix,
        # then the str-vs-Path draw.
        data = self._random_image_data(
            shape,
            components=components,
            binary=binary,
            add_nans=add_nans,
            force_binary_foreground=force_binary_foreground,
        )
        affine = np.diag((*spacing, 1))
        if suffix is None:
            suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img', '.mnc'))
        path = self.dir / f'{stem}{suffix}'
        if np.random.rand() > 0.5:
            path = str(path)
        image = ScalarImage(tensor=data, affine=affine, check_nans=not add_nans)
        image.save(path)
        return path

    def assertTensorNotEqual(self, *args, **kwargs):
        """Assert that ``assert_array_equal(*args, **kwargs)`` fails."""
        assert_raises(AssertionError, assert_array_equal, *args, **kwargs)

    def assertTensorEqual(self, *args, **kwargs):
        """Assert exact array equality via ``numpy.testing.assert_array_equal``."""
        assert_array_equal(*args, **kwargs)

    def assertTensorAlmostEqual(self, *args, **kwargs):
        """Assert approximate equality via ``assert_array_almost_equal``."""
        assert_array_almost_equal(*args, **kwargs)
173