Passed
Pull Request — master (#332)
by Fernando
04:34
created

tests.data.test_io.TestIO.test_read_dicom_dir()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import tempfile
2
import unittest
3
from pathlib import Path
4
import torch
5
import pytest
6
import numpy as np
7
from numpy.testing import assert_array_equal
8
import nibabel as nib
9
import SimpleITK as sitk
10
from ..utils import TorchioTestCase
11
from torchio.data import io, ScalarImage
12
13
14
class TestIO(TorchioTestCase):
    """Tests for the `torchio.data.io` module."""

    def setUp(self):
        """Prepare paths to test images and a reference 4x4 matrix."""
        super().setUp()
        self.nii_path = self.get_image_path('read_image')
        self.dicom_dir = self.get_tests_data_dir() / 'dicom'
        self.dicom_path = self.dicom_dir / 'IMG0001.dcm'
        string = (
            '1.5 0.18088 -0.124887 0.65072 '
            '-0.20025 0.965639 -0.165653 -11.6452 '
            '0.0906326 0.18661 0.978245 11.4002 '
            '0 0 0 1 '
        )
        # 16 space-separated values, reshaped row by row into a 4x4 matrix
        tensor = torch.from_numpy(np.fromstring(string, sep=' ').reshape(4, 4))
        self.matrix = tensor

    def test_read_image(self):
        """Reading the reference image with `io.read_image` must not raise."""
        # TODO: find an image readable by nibabel but not by SimpleITK so
        # that this test also covers that code path.
        io.read_image(self.nii_path)

    def test_save_rgb(self):
        """Saving a (1, 4, 5, 1) image with a .jpg suffix emits a UserWarning."""
        im = ScalarImage(tensor=torch.rand(1, 4, 5, 1))
        with self.assertWarns(UserWarning):
            im.save(self.dir / 'test.jpg')

    def test_read_dicom_file(self):
        """A single DICOM file is read as a tensor of shape (1, 88, 128, 1)."""
        tensor, _ = io.read_image(self.dicom_path)
        self.assertEqual(tuple(tensor.shape), (1, 88, 128, 1))

    def test_read_dicom_dir(self):
        """A DICOM directory is read as a tensor of shape (1, 88, 128, 17)."""
        tensor, _ = io.read_image(self.dicom_dir)
        self.assertEqual(tuple(tensor.shape), (1, 88, 128, 17))

    def test_dicom_dir_missing(self):
        """`_read_dicom` raises FileNotFoundError for a nonexistent path."""
        with self.assertRaises(FileNotFoundError):
            io._read_dicom('missing')

    def test_dicom_dir_no_files(self):
        """`_read_dicom` raises FileNotFoundError for an empty directory."""
        empty = self.dir / 'empty'
        empty.mkdir()
        with self.assertRaises(FileNotFoundError):
            io._read_dicom(empty)

    def write_read_matrix(self, suffix):
        """Write ``self.matrix`` to ``matrix<suffix>`` and check it round-trips.

        Args:
            suffix: File extension (e.g. ``'.tfm'``) passed through to
                ``io.write_matrix`` / ``io.read_matrix`` via the file name.
        """
        out_path = self.dir / f'matrix{suffix}'
        io.write_matrix(self.matrix, out_path)
        matrix = io.read_matrix(out_path)
        # Use the TestCase assertion like the rest of this class: it reports
        # which suffix failed and, unlike a bare assert, is not removed by -O.
        self.assertTrue(
            torch.allclose(matrix, self.matrix),
            f'Matrix round-trip failed for suffix {suffix!r}',
        )

    def test_matrix_itk(self):
        """Round-trip the matrix through the .tfm and .h5 formats."""
        self.write_read_matrix('.tfm')
        self.write_read_matrix('.h5')

    def test_matrix_txt(self):
        """Round-trip the matrix through the .txt format."""
        self.write_read_matrix('.txt')
69
70
71
# pytest.mark.parametrize is not supported on unittest.TestCase methods, so
# the cross-library round-trip test below lives at module level instead of
# inside TestIO. Each entry is (save_lib, load_lib, dims).
libs = 'sitk', 'nibabel'
parameters = [
    (writer, reader, ndim)
    for writer in libs
    for reader in libs
    for ndim in (2, 3, 4)
]
78
79
80
@pytest.mark.parametrize(('save_lib', 'load_lib', 'dims'), parameters)
def test_write_nd_with_a_read_it_with_b(save_lib, load_lib, dims):
    """Write an image with one backend and read it back with another.

    Args:
        save_lib: Backend suffix of the private writer, ``io._write_<save_lib>``.
        load_lib: Backend suffix of the private reader, ``io._read_<load_lib>``.
        dims: Layout of the random test tensor, based on shape [1, 4, 5, 6]:
            2 sets the last dimension to 1, 4 sets the first dimension to 2
            (presumably the channel axis), 3 keeps the base shape.
    """
    shape = [1, 4, 5, 6]
    if dims == 2:
        shape[-1] = 1
    elif dims == 4:
        shape[0] = 2
    tensor = torch.randn(*shape)
    affine = np.eye(4)
    tempdir = Path(tempfile.gettempdir()) / '.torchio_tests'
    tempdir.mkdir(exist_ok=True)
    # One file per parameter combination, so parallel test workers
    # (e.g. pytest-xdist) cannot overwrite each other's output.
    path = tempdir / f'test_io_{save_lib}_{load_lib}_{dims}d.nii'
    save_function = getattr(io, f'_write_{save_lib}')
    # Read with load_lib (not save_lib) so the cross-backend combinations
    # are actually exercised.
    load_function = getattr(io, f'_read_{load_lib}')
    save_function(tensor, affine, path)
    loaded_tensor, loaded_affine = load_function(path)
    message = f'Save lib: {save_lib}; load lib: {load_lib}; dims: {dims}'
    assert_array_equal(tensor.squeeze(), loaded_tensor.squeeze(), message)
    assert_array_equal(affine, loaded_affine, message)
101