Passed
Push — master ( 0bf8ef...e85db2 )
by Fernando
01:12
created

tests.data.test_io.TestIO.test_sitk_to_affine()   A

Complexity

Conditions 1

Size

Total Lines 13
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 13
rs 9.75
c 0
b 0
f 0
cc 1
nop 1
1
import tempfile
2
from pathlib import Path
3
4
import torch
5
import pytest
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ..utils import TorchioTestCase
10
from torchio.data import io, ScalarImage
11
12
13
class TestIO(TorchioTestCase):
    """Tests for `io` module."""

    def setUp(self):
        """Set up paths to the test images and a reference 4x4 matrix."""
        super().setUp()
        self.nii_path = self.get_image_path('read_image')
        self.dicom_dir = self.get_tests_data_dir() / 'dicom'
        self.dicom_path = self.dicom_dir / 'IMG0001.dcm'
        # Reference matrix for the write/read round-trip tests. Written as an
        # explicit literal (instead of parsing a string with np.fromstring) so
        # the values and the 4x4 shape are visible at a glance.
        matrix = np.array(
            [
                [1.5, 0.18088, -0.124887, 0.65072],
                [-0.20025, 0.965639, -0.165653, -11.6452],
                [0.0906326, 0.18661, 0.978245, 11.4002],
                [0.0, 0.0, 0.0, 1.0],
            ],
            dtype=np.float64,
        )
        self.matrix = torch.as_tensor(matrix)

    def test_read_image(self):
        # TODO: use a file that is readable by nibabel but not by SimpleITK.
        io.read_image(self.nii_path)

    def test_save_rgb(self):
        # Saving a random float tensor as JPEG must emit a RuntimeWarning.
        im = ScalarImage(tensor=torch.rand(1, 4, 5, 1))
        with self.assertWarns(RuntimeWarning):
            im.save(self.dir / 'test.jpg')

    def test_read_dicom_file(self):
        tensor, _ = io.read_image(self.dicom_path)
        self.assertEqual(tuple(tensor.shape), (1, 88, 128, 1))

    def test_read_dicom_dir(self):
        tensor, _ = io.read_image(self.dicom_dir)
        self.assertEqual(tuple(tensor.shape), (1, 88, 128, 17))

    def test_dicom_dir_missing(self):
        # Point inside self.dir rather than at a bare relative name, so the
        # outcome does not depend on the current working directory.
        missing = self.dir / 'missing'
        with self.assertRaises(FileNotFoundError):
            io._read_dicom(missing)

    def test_dicom_dir_no_files(self):
        empty = self.dir / 'empty'
        empty.mkdir()
        with self.assertRaises(FileNotFoundError):
            io._read_dicom(empty)

    def write_read_matrix(self, suffix):
        """Write ``self.matrix`` to ``matrix<suffix>`` and check it reads back.

        Args:
            suffix: File extension (e.g. ``'.tfm'``, ``'.h5'``, ``'.txt'``)
                that selects the on-disk matrix format.
        """
        out_path = self.dir / f'matrix{suffix}'
        io.write_matrix(self.matrix, out_path)
        matrix = io.read_matrix(out_path)
        assert torch.allclose(matrix, self.matrix)

    def test_matrix_itk(self):
        self.write_read_matrix('.tfm')
        self.write_read_matrix('.h5')

    def test_matrix_txt(self):
        self.write_read_matrix('.txt')

    def test_ensure_4d_5d(self):
        tensor = torch.rand(3, 4, 5, 1, 2)
        assert io.ensure_4d(tensor).shape == (2, 3, 4, 5)

    def test_ensure_4d_5d_t_gt_1(self):
        tensor = torch.rand(3, 4, 5, 2, 2)
        with self.assertRaises(ValueError):
            io.ensure_4d(tensor)

    def test_ensure_4d_2d(self):
        tensor = torch.rand(4, 5)
        assert io.ensure_4d(tensor).shape == (1, 4, 5, 1)

    def test_ensure_4d_2d_3dims_rgb_first(self):
        tensor = torch.rand(3, 4, 5)
        assert io.ensure_4d(tensor).shape == (3, 4, 5, 1)

    def test_ensure_4d_2d_3dims_rgb_last(self):
        tensor = torch.rand(4, 5, 3)
        assert io.ensure_4d(tensor).shape == (3, 4, 5, 1)

    def test_ensure_4d_3d(self):
        tensor = torch.rand(4, 5, 6)
        assert io.ensure_4d(tensor).shape == (1, 4, 5, 6)

    def test_ensure_4d_2_spatial_dims(self):
        tensor = torch.rand(4, 5, 6)
        assert io.ensure_4d(tensor, num_spatial_dims=2).shape == (4, 5, 6, 1)

    def test_ensure_4d_3_spatial_dims(self):
        tensor = torch.rand(4, 5, 6)
        assert io.ensure_4d(tensor, num_spatial_dims=3).shape == (1, 4, 5, 6)

    def test_ensure_4d_nd_not_supported(self):
        tensor = torch.rand(1, 2, 3, 4, 5)
        with self.assertRaises(ValueError):
            io.ensure_4d(tensor)

    def test_sitk_to_nib(self):
        data = np.random.rand(10, 12)
        image = sitk.GetImageFromArray(data)
        tensor, _ = io.sitk_to_nib(image)
        self.assertAlmostEqual(data.sum(), tensor.sum())

    def test_sitk_to_affine(self):
        spacing = 1, 2, 3
        direction_lps = -1, 0, 0, 0, -1, 0, 0, 0, 1
        origin_lps = origin_l, origin_p, origin_s = -10, -20, 30
        image = sitk.GetImageFromArray(np.random.rand(10, 20, 30))
        image.SetDirection(direction_lps)
        image.SetSpacing(spacing)
        image.SetOrigin(origin_lps)
        # Going from LPS to RAS negates the first two axes. The direction set
        # above is therefore the identity in RAS, and the expected affine is
        # diagonal in the spacing with the RAS origin as translation.
        origin_ras = -origin_l, -origin_p, origin_s
        fixture = np.diag((*spacing, 1))
        fixture[:3, 3] = origin_ras
        affine = io.get_ras_affine_from_sitk(image)
        self.assertTensorAlmostEqual(fixture, affine)
126
127
128
# This doesn't work as a method of the class: pytest parametrization is not
# supported on unittest.TestCase methods, so the round-trip test lives at
# module level. Build every (writer, reader, dims) combination, ordered by
# writer, then reader, then dims.
libs = 'sitk', 'nibabel'
parameters = [
    (writer, reader, num_dims)
    for writer in libs
    for reader in libs
    for num_dims in (2, 3, 4)
]
135
136
137
@pytest.mark.parametrize(('save_lib', 'load_lib', 'dims'), parameters)
def test_write_nd_with_a_read_it_with_b(save_lib, load_lib, dims):
    """Write an image with ``save_lib`` and read it back with ``load_lib``.

    ``dims`` selects the written tensor shape: 2 -> ``(1, 4, 5, 1)``,
    3 -> ``(1, 4, 5, 6)``, 4 -> ``(2, 4, 5, 6)``. The data (after squeezing)
    and the identity affine must survive the round trip unchanged.
    """
    shape = [1, 4, 5, 6]
    if dims == 2:
        shape[-1] = 1
    elif dims == 4:
        shape[0] = 2
    tensor = torch.randn(*shape)
    affine = np.eye(4)
    save_function = getattr(io, f'_write_{save_lib}')
    # Look up the reader from load_lib (not save_lib) so that every
    # writer/reader combination in `parameters` is actually exercised.
    load_function = getattr(io, f'_read_{load_lib}')
    # A fresh directory per call, instead of one fixed shared path, keeps
    # parametrized cases and concurrent test runs from clobbering each other.
    with tempfile.TemporaryDirectory() as tempdir:
        path = Path(tempdir) / 'test_io.nii'
        save_function(tensor, affine, path)
        loaded_tensor, loaded_affine = load_function(path)
        TorchioTestCase.assertTensorEqual(
            tensor.squeeze(),
            loaded_tensor.squeeze(),
            f'Save lib: {save_lib}; load lib: {load_lib}; dims: {dims}',
        )
        TorchioTestCase.assertTensorEqual(affine, loaded_affine)
158
159
160
class TestNibabelToSimpleITK(TorchioTestCase):
    """Tests for converting channels-first arrays with `io.nib_to_sitk`."""

    def setUp(self):
        super().setUp()
        self.affine = np.eye(4)

    def _convert_and_check(self, shape, dimension, size, components, **kwargs):
        """Convert a random array of ``shape`` and check the SimpleITK image.

        Args:
            shape: Shape of the random input array.
            dimension: Expected ``GetDimension()`` of the result.
            size: Expected ``GetSize()`` of the result.
            components: Expected ``GetNumberOfComponentsPerPixel()``.
            **kwargs: Extra keyword arguments forwarded to ``nib_to_sitk``.
        """
        array = np.random.rand(*shape)
        converted = io.nib_to_sitk(array, self.affine, **kwargs)
        assert converted.GetDimension() == dimension
        assert converted.GetSize() == size
        assert converted.GetNumberOfComponentsPerPixel() == components

    def test_wrong_num_dims(self):
        with self.assertRaises(ValueError):
            io.nib_to_sitk(np.random.rand(10, 10), self.affine)

    def test_2d_single(self):
        self._convert_and_check((1, 10, 12, 1), 2, (10, 12), 1)

    def test_2d_multi(self):
        self._convert_and_check((5, 10, 12, 1), 2, (10, 12), 5)

    def test_2d_3d_single(self):
        self._convert_and_check(
            (1, 10, 12, 1), 3, (10, 12, 1), 1, force_3d=True,
        )

    def test_2d_3d_multi(self):
        self._convert_and_check(
            (5, 10, 12, 1), 3, (10, 12, 1), 5, force_3d=True,
        )

    def test_3d_single(self):
        self._convert_and_check((1, 8, 10, 12), 3, (8, 10, 12), 1)

    def test_3d_multi(self):
        self._convert_and_check((5, 8, 10, 12), 3, (8, 10, 12), 5)
211