Passed
Push — master ( 85bce8...47d3da )
by Fernando
01:13
created

tests.data.test_io.TestIO.save_load_save_load()   A

Complexity

Conditions 2

Size

Total Lines 16
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 16
nop 2
dl 0
loc 16
rs 9.6
c 0
b 0
f 0
1
import tempfile
2
import unittest
3
from pathlib import Path
4
import torch
5
import pytest
6
import numpy as np
7
from numpy.testing import assert_array_equal
8
import nibabel as nib
9
import SimpleITK as sitk
10
from ..utils import TorchioTestCase
11
from torchio.data import io, ScalarImage
12
13
14
class TestIO(TorchioTestCase):
    """Exercise the reading and writing helpers of the `io` module."""

    def setUp(self):
        """Create the DICOM fixture and the reference 4x4 matrix."""
        super().setUp()
        self.write_dicom()
        flat_values = (
            '1.5 0.18088 -0.124887 0.65072 '
            '-0.20025 0.965639 -0.165653 -11.6452 '
            '0.0906326 0.18661 0.978245 11.4002 '
            '0 0 0 1 '
        )
        # Parse the 16 space-separated numbers and fold them into four rows
        parsed = np.fromstring(flat_values, sep=' ')
        self.matrix = torch.from_numpy(parsed.reshape(4, 4))

    def write_dicom(self):
        """Write a DICOM file derived from the sample image.

        Sets ``self.dicom_dir``, ``self.dicom_path`` and ``self.nii_path``
        for the tests. ``self.dir`` and ``get_image_path`` are not defined
        in this class (presumably provided by ``TorchioTestCase``).
        """
        self.dicom_dir = self.dir / 'dicom'
        self.dicom_dir.mkdir(exist_ok=True)
        self.dicom_path = self.dicom_dir / 'dicom.dcm'
        self.nii_path = self.get_image_path('read_image')

        source = sitk.ReadImage(str(self.nii_path))
        as_uint16 = sitk.Cast(source, sitk.sitkUInt16)
        # dicom reader supports 2D only, so keep just index 0
        reduced = as_uint16[0]

        dicom_writer = sitk.ImageFileWriter()
        dicom_writer.SetFileName(str(self.dicom_path))
        dicom_writer.Execute(reduced)

    def test_read_image(self):
        """Reading the sample image must not raise."""
        # I need to find something readable by nib but not sitk
        io.read_image(self.nii_path)

    def test_save_rgb(self):
        """Saving a random (1, 4, 5, 1) image as JPEG emits a UserWarning."""
        image = ScalarImage(tensor=torch.rand(1, 4, 5, 1))
        jpg_path = self.dir / 'test.jpg'
        with self.assertWarns(UserWarning):
            image.save(jpg_path)

    def test_read_dicom_file(self):
        """Reading a single DICOM file must not raise."""
        io.read_image(self.dicom_path)

    def test_read_dicom_dir(self):
        """Reading the directory that holds the DICOM file must not raise."""
        io.read_image(self.dicom_dir)

    def test_dicom_dir_missing(self):
        """A path that does not exist raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            io._read_dicom('missing')

    def test_dicom_dir_no_files(self):
        """An existing but empty directory raises FileNotFoundError."""
        empty_dir = self.dir / 'empty'
        empty_dir.mkdir()
        with self.assertRaises(FileNotFoundError):
            io._read_dicom(empty_dir)

    def write_read_matrix(self, suffix):
        """Round-trip ``self.matrix`` through a file with the given suffix.

        Args:
            suffix: File extension (including the dot) appended to the
                ``matrix`` file stem; it is part of the path handed to
                ``io.write_matrix`` and ``io.read_matrix``.
        """
        matrix_path = self.dir / f'matrix{suffix}'
        io.write_matrix(self.matrix, matrix_path)
        recovered = io.read_matrix(matrix_path)
        assert torch.allclose(recovered, self.matrix)

    def test_matrix_itk(self):
        """Round-trip the matrix through the ``.tfm`` and ``.h5`` formats."""
        for suffix in ('.tfm', '.h5'):
            self.write_read_matrix(suffix)

    def test_matrix_txt(self):
        """Round-trip the matrix through a ``.txt`` file."""
        self.write_read_matrix('.txt')
77
78
79
# The parametrized round-trip test below lives at module level because it
# doesn't work as a method of the class.
# Every (save library, load library, dimensionality) combination is listed.
libs = 'sitk', 'nibabel'
parameters = [
    (writer_name, reader_name, n_dims)
    for writer_name in libs
    for reader_name in libs
    for n_dims in (2, 3, 4)
]
86
87
88
@pytest.mark.parametrize(('save_lib', 'load_lib', 'dims'), parameters)
def test_write_nd_with_a_read_it_with_b(save_lib, load_lib, dims):
    """Write an image with one library and read it back with another.

    The tensor and affine returned by ``io._read_<load_lib>`` must equal
    the ones handed to ``io._write_<save_lib>`` (singleton dimensions of
    the tensor are ignored).

    Args:
        save_lib: Suffix of the ``io._write_*`` function used to save.
        load_lib: Suffix of the ``io._read_*`` function used to load.
        dims: Dimensionality of the test image (2, 3 or 4).
    """
    shape = [1, 4, 5, 6]
    if dims == 2:
        shape[-1] = 1  # a single entry along the last axis
    elif dims == 4:
        shape[0] = 2  # two entries along the first axis
    tensor = torch.randn(*shape)
    affine = np.eye(4)
    tempdir = Path(tempfile.gettempdir()) / '.torchio_tests'
    tempdir.mkdir(exist_ok=True)
    path = tempdir / 'test_io.nii'
    save_function = getattr(io, f'_write_{save_lib}')
    # The reader must come from ``load_lib``; looking it up with
    # ``save_lib`` would never exercise the cross-library combinations.
    load_function = getattr(io, f'_read_{load_lib}')
    save_function(tensor, affine, path)
    loaded_tensor, loaded_affine = load_function(path)
    message = f'Save lib: {save_lib}; load lib: {load_lib}; dims: {dims}'
    assert_array_equal(
        tensor.squeeze(), loaded_tensor.squeeze(),
        err_msg=message,
    )
    assert_array_equal(affine, loaded_affine, err_msg=message)
109