Passed
Pull Request — master (#286)
by Fernando
01:09
created

tests.data.test_io.TestIO.test_nd()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 1
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import tempfile
2
import unittest
3
from pathlib import Path
4
import torch
5
import pytest
6
import numpy as np
7
from numpy.testing import assert_array_equal
8
import nibabel as nib
9
import SimpleITK as sitk
10
from ..utils import TorchioTestCase
11
from torchio.data import io, ScalarImage
12
13
14
class TestIO(TorchioTestCase):
    """Tests for `io` module."""

    def setUp(self):
        """Create the DICOM fixture and the reference 4x4 matrix."""
        super().setUp()
        self.write_dicom()
        # Sixteen space-separated values, parsed row by row into a 4x4 matrix.
        text = (
            '1.5 0.18088 -0.124887 0.65072 '
            '-0.20025 0.965639 -0.165653 -11.6452 '
            '0.0906326 0.18661 0.978245 11.4002 '
            '0 0 0 1 '
        )
        values = np.fromstring(text, sep=' ')
        self.matrix = torch.from_numpy(values.reshape(4, 4))

    def write_dicom(self):
        """Write one slice of the 'read_image' fixture to ``dicom/dicom.dcm``.

        Sets ``self.dicom_dir``, ``self.dicom_path`` and ``self.nii_path``
        for use by the tests.
        """
        self.dicom_dir = self.dir / 'dicom'
        self.dicom_dir.mkdir(exist_ok=True)
        self.dicom_path = self.dicom_dir / 'dicom.dcm'
        self.nii_path = self.get_image_path('read_image')
        # Read the source image and cast it to unsigned 16-bit.
        volume = sitk.Cast(
            sitk.ReadImage(str(self.nii_path)),
            sitk.sitkUInt16,
        )
        single_slice = volume[0]  # dicom reader supports 2D only
        dicom_writer = sitk.ImageFileWriter()
        dicom_writer.SetFileName(str(self.dicom_path))
        dicom_writer.Execute(single_slice)

    def test_read_image(self):
        """Reading the source image path must not raise."""
        # I need to find something readable by nib but not sitk (MINC?)
        io.read_image(self.nii_path)

    def test_read_dicom_file(self):
        """Reading the single DICOM file must not raise."""
        io.read_image(self.dicom_path)

    def test_read_dicom_dir(self):
        """Reading the directory containing the DICOM file must not raise."""
        io.read_image(self.dicom_dir)

    def test_dicom_dir_missing(self):
        """A path that does not exist raises ``FileNotFoundError``."""
        self.assertRaises(FileNotFoundError, io._read_dicom, 'missing')

    def test_dicom_dir_no_files(self):
        """A freshly created, empty directory raises ``FileNotFoundError``."""
        empty_dir = self.dir / 'empty'
        empty_dir.mkdir()
        self.assertRaises(FileNotFoundError, io._read_dicom, empty_dir)

    def write_read_matrix(self, suffix):
        """Write ``self.matrix`` to ``matrix<suffix>`` and check the read-back.

        Args:
            suffix: File extension (including the dot) used for the output
                file name.
        """
        matrix_path = self.dir / f'matrix{suffix}'
        io.write_matrix(self.matrix, matrix_path)
        reloaded = io.read_matrix(matrix_path)
        assert torch.allclose(reloaded, self.matrix)

    def test_matrix_itk(self):
        """Round-trip the matrix through the ``.tfm`` and ``.h5`` formats."""
        for extension in ('.tfm', '.h5'):
            self.write_read_matrix(extension)

    def test_matrix_txt(self):
        """Round-trip the matrix through the ``.txt`` format."""
        self.write_read_matrix('.txt')
72
73
74
# The parametrized round-trip test below is a module-level function because
# it doesn't work as a method of the class.
# Every (save_lib, load_lib, dims) combination: save library outermost,
# load library next, dimensionality innermost.
libs = 'sitk', 'nibabel'
parameters = [
    (save_lib, load_lib, dims)
    for save_lib in libs
    for load_lib in libs
    for dims in (2, 3, 4)
]
81
82
83
@pytest.mark.parametrize(('save_lib', 'load_lib', 'dims'), parameters)
def test_write_nd_with_a_read_it_with_b(save_lib, load_lib, dims):
    """Write a random tensor with one library and read it back with another.

    Args:
        save_lib: Suffix of the ``io._write_<lib>`` function used to save.
        load_lib: Suffix of the ``io._read_<lib>`` function used to load.
        dims: 2, 3 or 4; selects the shape of the generated tensor.
    """
    shape = [1, 4, 5, 6]
    if dims == 2:
        shape[-1] = 1  # singleton last axis
    elif dims == 4:
        shape[0] = 2  # two entries along the first axis
    tensor = torch.randn(*shape)
    affine = np.eye(4)
    tempdir = Path(tempfile.gettempdir()) / '.torchio_tests'
    tempdir.mkdir(exist_ok=True)
    path = tempdir / 'test_io.nii'
    save_function = getattr(io, f'_write_{save_lib}')
    # The reader must come from ``load_lib``; looking it up with ``save_lib``
    # would only ever test each library against itself.
    load_function = getattr(io, f'_read_{load_lib}')
    save_function(tensor, affine, path)
    loaded_tensor, loaded_affine = load_function(path)
    # Singleton axes are squeezed on both sides before comparing values.
    assert_array_equal(
        tensor.squeeze(), loaded_tensor.squeeze(),
        f'Save lib: {save_lib}; load lib: {load_lib}; dims: {dims}'
    )
    assert_array_equal(affine, loaded_affine)
104