Passed
Push — master ( 0e3b0b...4497b8 )
by Fernando
01:17
created

TestNibabelToSimpleITK.test_3d_single()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
import tempfile
2
from pathlib import Path
3
4
import torch
5
import pytest
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ..utils import TorchioTestCase
10
from torchio.data import io, ScalarImage
11
from torchio.data.io import ensure_4d, nib_to_sitk, sitk_to_nib
12
13
14
class TestIO(TorchioTestCase):
    """Exercise the reading, writing and reshaping helpers of `io`."""

    def setUp(self):
        """Store fixture paths and a 4x4 matrix used by the matrix tests."""
        super().setUp()
        self.nii_path = self.get_image_path('read_image')
        self.dicom_dir = self.get_tests_data_dir() / 'dicom'
        self.dicom_path = self.dicom_dir / 'IMG0001.dcm'
        # Sixteen space-separated values, parsed and reshaped to (4, 4) below.
        flat_values = (
            '1.5 0.18088 -0.124887 0.65072 '
            '-0.20025 0.965639 -0.165653 -11.6452 '
            '0.0906326 0.18661 0.978245 11.4002 '
            '0 0 0 1 '
        )
        parsed = np.fromstring(flat_values, sep=' ')
        self.matrix = torch.from_numpy(parsed.reshape(4, 4))

    def test_read_image(self):
        """Reading the fixture image must not raise."""
        # TODO: find a file that nib can read but sitk cannot.
        io.read_image(self.nii_path)

    def test_save_rgb(self):
        """Saving a (1, 4, 5, 1) image as JPEG emits a RuntimeWarning."""
        image = ScalarImage(tensor=torch.rand(1, 4, 5, 1))
        destination = self.dir / 'test.jpg'
        with self.assertWarns(RuntimeWarning):
            image.save(destination)

    def test_read_dicom_file(self):
        """A single DICOM file is read as a (1, 88, 128, 1) tensor."""
        data, _ = io.read_image(self.dicom_path)
        expected_shape = (1, 88, 128, 1)
        self.assertEqual(tuple(data.shape), expected_shape)

    def test_read_dicom_dir(self):
        """The DICOM directory is read as a (1, 88, 128, 17) tensor."""
        data, _ = io.read_image(self.dicom_dir)
        expected_shape = (1, 88, 128, 17)
        self.assertEqual(tuple(data.shape), expected_shape)

    def test_dicom_dir_missing(self):
        """`_read_dicom` raises FileNotFoundError for the path 'missing'."""
        self.assertRaises(FileNotFoundError, io._read_dicom, 'missing')

    def test_dicom_dir_no_files(self):
        """`_read_dicom` raises FileNotFoundError for an empty directory."""
        empty_dir = self.dir / 'empty'
        empty_dir.mkdir()
        self.assertRaises(FileNotFoundError, io._read_dicom, empty_dir)

    def write_read_matrix(self, suffix):
        """Write `self.matrix` to a file with `suffix` and read it back.

        The matrix read from disk must be close to the one written.
        """
        target = self.dir / f'matrix{suffix}'
        io.write_matrix(self.matrix, target)
        recovered = io.read_matrix(target)
        assert torch.allclose(recovered, self.matrix)

    def test_matrix_itk(self):
        """Round-trip the matrix through '.tfm' and '.h5' files."""
        for suffix in ('.tfm', '.h5'):
            self.write_read_matrix(suffix)

    def test_matrix_txt(self):
        """Round-trip the matrix through a '.txt' file."""
        self.write_read_matrix('.txt')

    def test_ensure_4d_5d(self):
        """A (3, 4, 5, 1, 2) input comes out with shape (2, 3, 4, 5)."""
        result = ensure_4d(torch.rand(3, 4, 5, 1, 2))
        assert result.shape == (2, 3, 4, 5)

    def test_ensure_4d_5d_t_gt_1(self):
        """A (3, 4, 5, 2, 2) input is rejected with ValueError."""
        five_d = torch.rand(3, 4, 5, 2, 2)
        self.assertRaises(ValueError, ensure_4d, five_d)

    def test_ensure_4d_2d(self):
        """A (4, 5) input comes out with shape (1, 4, 5, 1)."""
        result = ensure_4d(torch.rand(4, 5))
        assert result.shape == (1, 4, 5, 1)

    def test_ensure_4d_2d_3dims_rgb_first(self):
        """A (3, 4, 5) input comes out with shape (3, 4, 5, 1)."""
        result = ensure_4d(torch.rand(3, 4, 5))
        assert result.shape == (3, 4, 5, 1)

    def test_ensure_4d_2d_3dims_rgb_last(self):
        """A (4, 5, 3) input comes out with shape (3, 4, 5, 1)."""
        result = ensure_4d(torch.rand(4, 5, 3))
        assert result.shape == (3, 4, 5, 1)

    def test_ensure_4d_3d(self):
        """A (4, 5, 6) input comes out with shape (1, 4, 5, 6)."""
        result = ensure_4d(torch.rand(4, 5, 6))
        assert result.shape == (1, 4, 5, 6)

    def test_ensure_4d_2_spatial_dims(self):
        """With num_spatial_dims=2, (4, 5, 6) becomes (4, 5, 6, 1)."""
        result = ensure_4d(torch.rand(4, 5, 6), num_spatial_dims=2)
        assert result.shape == (4, 5, 6, 1)

    def test_ensure_4d_3_spatial_dims(self):
        """With num_spatial_dims=3, (4, 5, 6) becomes (1, 4, 5, 6)."""
        result = ensure_4d(torch.rand(4, 5, 6), num_spatial_dims=3)
        assert result.shape == (1, 4, 5, 6)

    def test_ensure_4d_nd_not_supported(self):
        """A (1, 2, 3, 4, 5) input is rejected with ValueError."""
        five_d = torch.rand(1, 2, 3, 4, 5)
        self.assertRaises(ValueError, ensure_4d, five_d)

    def test_sitk_to_nib(self):
        """Converting a SimpleITK image preserves the sum of its values."""
        array = np.random.rand(10, 12)
        sitk_image = sitk.GetImageFromArray(array)
        tensor, _ = sitk_to_nib(sitk_image)
        self.assertAlmostEqual(array.sum(), tensor.sum())
113
114
115
# The parametrized test that consumes ``parameters`` is a module-level
# function: per the original author's note it does not work as a method of
# the test class (presumably because pytest parametrization is not applied
# to methods of unittest-style test cases).
libs = 'sitk', 'nibabel'
# Every (save_lib, load_lib, dims) combination, with dims in 2, 3, 4.
parameters = [
    (save_lib, load_lib, dims)
    for save_lib in libs
    for load_lib in libs
    for dims in (2, 3, 4)
]
122
123
124
@pytest.mark.parametrize(('save_lib', 'load_lib', 'dims'), parameters)
def test_write_nd_with_a_read_it_with_b(save_lib, load_lib, dims):
    """Write a random image with one backend and read it with another.

    Args:
        save_lib: Suffix selecting the ``io._write_<save_lib>`` function.
        load_lib: Suffix selecting the ``io._read_<load_lib>`` function.
        dims: 2, 3 or 4; selects the shape of the random tensor written.

    The tensor read back (after squeezing) and the affine must equal the
    ones that were written.
    """
    shape = [1, 4, 5, 6]
    if dims == 2:
        shape[-1] = 1  # singleton last axis
    elif dims == 4:
        shape[0] = 2  # more than one entry along the first axis
    tensor = torch.randn(*shape)
    affine = np.eye(4)
    tempdir = Path(tempfile.gettempdir()) / '.torchio_tests'
    tempdir.mkdir(exist_ok=True)
    path = tempdir / 'test_io.nii'
    save_function = getattr(io, f'_write_{save_lib}')
    # The reader must be selected by ``load_lib``. Selecting it by
    # ``save_lib`` would only ever test each backend against itself.
    load_function = getattr(io, f'_read_{load_lib}')
    save_function(tensor, affine, path)
    loaded_tensor, loaded_affine = load_function(path)
    TorchioTestCase.assertTensorEqual(
        tensor.squeeze(), loaded_tensor.squeeze(),
        f'Save lib: {save_lib}; load lib: {load_lib}; dims: {dims}'
    )
    TorchioTestCase.assertTensorEqual(affine, loaded_affine)
145
146
147
class TestNibabelToSimpleITK(TorchioTestCase):
    """Check the images that `nib_to_sitk` builds from random arrays."""

    def setUp(self):
        """Use a 4x4 identity affine for every conversion."""
        super().setUp()
        self.affine = np.eye(4)

    def _convert(self, shape, **kwargs):
        """Run `nib_to_sitk` on a random array of the given shape."""
        array = np.random.rand(*shape)
        return nib_to_sitk(array, self.affine, **kwargs)

    @staticmethod
    def _check(image, dimension, size, components):
        """Assert the dimension, size and component count of `image`."""
        assert image.GetDimension() == dimension
        assert image.GetSize() == size
        assert image.GetNumberOfComponentsPerPixel() == components

    def test_wrong_num_dims(self):
        """A (10, 10) array is rejected with ValueError."""
        self.assertRaises(ValueError, self._convert, (10, 10))

    def test_2d_single(self):
        """(1, 10, 12, 1) gives a 2D image with one component."""
        converted = self._convert((1, 10, 12, 1))
        self._check(converted, 2, (10, 12), 1)

    def test_2d_multi(self):
        """(5, 10, 12, 1) gives a 2D image with five components."""
        converted = self._convert((5, 10, 12, 1))
        self._check(converted, 2, (10, 12), 5)

    def test_2d_3d_single(self):
        """With force_3d, (1, 10, 12, 1) gives a 3D, one-component image."""
        converted = self._convert((1, 10, 12, 1), force_3d=True)
        self._check(converted, 3, (10, 12, 1), 1)

    def test_2d_3d_multi(self):
        """With force_3d, (5, 10, 12, 1) gives a 3D, five-component image."""
        converted = self._convert((5, 10, 12, 1), force_3d=True)
        self._check(converted, 3, (10, 12, 1), 5)

    def test_3d_single(self):
        """(1, 8, 10, 12) gives a 3D image with one component."""
        converted = self._convert((1, 8, 10, 12))
        self._check(converted, 3, (8, 10, 12), 1)

    def test_3d_multi(self):
        """(5, 8, 10, 12) gives a 3D image with five components."""
        converted = self._convert((5, 8, 10, 12))
        self._check(converted, 3, (8, 10, 12), 5)
198