Passed
Push — master ( 9ae791...b80c39 )
by Fernando
01:28
created

torchio.data.io._to_itk_convention()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
from pathlib import Path
2
from typing import Tuple
3
import torch
4
import numpy as np
5
import nibabel as nib
6
import SimpleITK as sitk
7
from .. import TypePath, TypeData
8
from ..utils import nib_to_sitk, sitk_to_nib
9
10
11
# Homogeneous 4x4 matrix that negates the x and y axes. Conjugating by it
# switches an affine between the RAS and LPS conventions (see the
# _to_itk_convention / _from_itk_convention helpers).
FLIPXY = np.diag(np.array([-1, -1, 1, 1]))
12
13
14
def read_image(
        path: TypePath,
        itk_first: bool = False,
        ) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image and its affine, trying two backends in turn.

    By default NiBabel is tried first, and SimpleITK is used if NiBabel raises
    ``nib.loadsave.ImageFileError``. With ``itk_first=True`` the order is
    reversed, and NiBabel is used if SimpleITK raises ``RuntimeError``.

    Args:
        path: Image file to read. For the SimpleITK reader, a directory is
            treated as a DICOM series.
        itk_first: Whether to try SimpleITK before NiBabel.

    Returns:
        The ``(tensor, affine)`` pair produced by whichever reader succeeded.
    """
    if not itk_first:
        try:
            return _read_nibabel(path)
        except nib.loadsave.ImageFileError:  # fall back to SimpleITK
            return _read_sitk(path)
    try:
        return _read_sitk(path)
    except RuntimeError:  # fall back to NiBabel
        return _read_nibabel(path)
29
30
31
def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image with NiBabel.

    The file is loaded without memory mapping, and the voxel data are
    requested as ``float32``.

    Returns:
        The voxel data as a tensor, and the image affine.
    """
    image = nib.load(str(path), mmap=False)
    voxels = image.get_fdata(dtype=np.float32)
    return torch.from_numpy(voxels), image.affine
37
38
39
def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image with SimpleITK.

    A directory is read as a DICOM series. Anything else is passed to
    ``sitk.ReadImage``. The image is split into ``(data, affine)`` by
    ``sitk_to_nib``, and the data are cast to ``float32`` if needed.
    """
    image = (
        _read_dicom(path) if Path(path).is_dir()  # assume DICOM
        else sitk.ReadImage(str(path))
    )
    array, affine = sitk_to_nib(image)
    # copy=False returns the same array when it is already float32
    array = array.astype(np.float32, copy=False)
    return torch.from_numpy(array), affine
49
50
51
def _read_dicom(directory: TypePath):
    """Read the DICOM series found in ``directory`` with SimpleITK.

    Raises:
        FileNotFoundError: If ``directory`` is not an existing directory, or
            if GDCM finds no series file names inside it.
    """
    folder = Path(directory)
    if not folder.is_dir():  # unreachable if called from _read_sitk
        raise FileNotFoundError(f'Directory "{folder}" not found')
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(str(folder))
    if not file_names:
        raise FileNotFoundError(
            f'The directory "{folder}"'
            ' does not seem to contain DICOM files'
        )
    series_reader.SetFileNames(file_names)
    return series_reader.Execute()
66
67
68
def write_image(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        itk_first: bool = False,
        ) -> None:
    """Save ``tensor`` with ``affine`` to ``path``, trying two backends in turn.

    By default NiBabel is tried first, and SimpleITK is used if NiBabel raises
    ``nib.loadsave.ImageFileError``. With ``itk_first=True`` the order is
    reversed, and NiBabel is used if SimpleITK raises ``RuntimeError``.
    """
    arguments = tensor, affine, path
    if not itk_first:
        try:
            _write_nibabel(*arguments)
        except nib.loadsave.ImageFileError:  # fall back to SimpleITK
            _write_sitk(*arguments)
        return
    try:
        _write_sitk(*arguments)
    except RuntimeError:  # fall back to NiBabel
        _write_nibabel(*arguments)
84
85
86
def _write_nibabel(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        ) -> None:
    """
    Save ``tensor`` as a NIfTI-1 image with NiBabel.

    ``path`` must have an extension that ``nibabel`` can use to write a
    NIfTI-1 image, such as '.nii.gz' or '.img'. The header's qform code is
    set to 1 and its sform code to 0 before writing.
    """
    image = nib.Nifti1Image(tensor.numpy(), affine)
    header = image.header
    header['qform_code'] = 1
    header['sform_code'] = 0
    image.to_filename(str(path))
99
100
101
def _write_sitk(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        ) -> None:
    """Convert ``tensor`` and ``affine`` with ``nib_to_sitk``, then save the
    resulting image with ``sitk.WriteImage``."""
    sitk.WriteImage(nib_to_sitk(tensor, affine), str(path))
108
109
110
def read_matrix(path: TypePath):
    """Read an affine transform and convert to tensor.

    The reader is chosen from the path suffix:

    - ``.tfm`` / ``.h5``: ITK transform file.
    - ``.txt`` / ``.trsf``: NiftyReg / blockmatching text matrix.

    Raises:
        ValueError: If the suffix is none of the above.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        tensor = _read_itk_matrix(path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        tensor = _read_niftyreg_matrix(path)
    else:
        # Without this branch ``tensor`` would be unbound below and the
        # caller would get an unhelpful UnboundLocalError.
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
    return tensor
0 ignored issues
show
introduced by
The variable tensor does not seem to be defined for all execution paths.
Loading history...
119
120
121
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform.

    The writer is chosen from the path suffix:

    - ``.tfm`` / ``.h5``: ITK transform file.
    - ``.txt`` / ``.trsf``: NiftyReg / blockmatching text matrix.

    Raises:
        ValueError: If the suffix is none of the above. Previously such paths
            were silently ignored and nothing was written.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
129
130
131
def _to_itk_convention(matrix):
    """RAS to LPS.

    Conjugate the homogeneous ``matrix`` with ``FLIPXY`` (which negates x and
    y), then return the inverse of the result.
    """
    flipped = np.dot(np.dot(FLIPXY, matrix), FLIPXY)
    return np.linalg.inv(flipped)
137
138
139
def _from_itk_convention(matrix):
    """LPS to RAS.

    Conjugate the homogeneous ``matrix`` with ``FLIPXY`` (which negates x and
    y), then return the inverse of the result.
    """
    flipped = np.dot(FLIPXY, np.dot(matrix, FLIPXY))
    return np.linalg.inv(flipped)
145
146
147
def _read_itk_matrix(path):
    """Read an affine transform in ITK's .tfm format.

    An ITK affine transform maps a point ``x`` to ``A (x - c) + t + c``, where
    ``A`` is the 3x3 matrix, ``t`` the translation (both from
    ``GetParameters``) and ``c`` the center (from ``GetFixedParameters``).
    The center is folded into the translation, so the returned 4x4 matrix is
    correct even when the stored center is not the origin. Transforms
    written by this module use a zero center and are unaffected.

    Returns:
        A 4x4 homogeneous ``torch.float64`` tensor in RAS convention.
    """
    transform = sitk.ReadTransform(str(path))
    parameters = transform.GetParameters()
    rotation_matrix = np.array(parameters[:9]).reshape(3, 3)
    translation = np.array(parameters[9:]).reshape(3)
    fixed_parameters = transform.GetFixedParameters()
    if len(fixed_parameters) == 3:  # center of rotation of a 3D affine
        center = np.array(fixed_parameters)
        translation = translation + center - rotation_matrix.dot(center)
    matrix = np.hstack([rotation_matrix, translation.reshape(3, 1)])
    homogeneous_matrix_lps = np.vstack([matrix, [0, 0, 0, 1]])
    homogeneous_matrix_ras = _from_itk_convention(homogeneous_matrix_lps)
    return torch.from_numpy(homogeneous_matrix_ras)
159
160
161
def _write_itk_matrix(matrix, tfm_path):
    """Save ``matrix`` to ``tfm_path`` as an ITK affine transform.

    The tfm file contains the matrix from floating to reference.
    """
    _matrix_to_itk_transform(matrix).WriteTransform(str(tfm_path))
165
166
167
def _matrix_to_itk_transform(matrix, dimensions=3):
    """Build a ``sitk.AffineTransform`` from a homogeneous RAS matrix.

    The matrix is first converted with ``_to_itk_convention``. Its top-left
    ``dimensions`` x ``dimensions`` block becomes the transform matrix, and
    the first ``dimensions`` entries of its last column (index
    ``dimensions``) become the translation.

    NOTE: ``_to_itk_convention`` multiplies by the 4x4 ``FLIPXY``, so only
    ``dimensions=3`` currently works end to end.
    """
    matrix = _to_itk_convention(matrix)
    rotation = matrix[:dimensions, :dimensions].ravel().tolist()
    # In a homogeneous matrix, the translation is the column right after the
    # linear part, i.e. column ``dimensions``.
    translation = matrix[:dimensions, dimensions].tolist()
    return sitk.AffineTransform(rotation, translation)
173
174
175
def _read_niftyreg_matrix(trsf_path):
176
    """Read a NiftyReg matrix and return it as a NumPy array"""
177
    matrix = np.loadtxt(trsf_path)
178
    matrix = np.linalg.inv(matrix)
179
    return torch.from_numpy(matrix)
180
181
182
def _write_niftyreg_matrix(matrix, txt_path):
183
    """Write an affine transform in NiftyReg's .txt format (ref -> flo)"""
184
    matrix = np.linalg.inv(matrix)
185
    np.savetxt(txt_path, matrix, fmt='%.8f')
186