Passed
Push — master ( 53ab14...c2608f )
by Fernando
01:07
created

torchio.data.io._write_nibabel()   B

Complexity

Conditions 6

Size

Total Lines 25
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 19
nop 5
dl 0
loc 25
rs 8.5166
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Tuple
4
import torch
5
import numpy as np
6
import nibabel as nib
7
import SimpleITK as sitk
8
from .. import TypePath, TypeData
9
from ..utils import nib_to_sitk, sitk_to_nib
10
11
12
# Homogeneous 4x4 diagonal matrix that negates the first two spatial axes
# (x and y). It is its own inverse. Used by _to_itk_convention and
# _from_itk_convention to convert affines between RAS and LPS.
FLIPXY = np.diag((-1, -1, 1, 1))
13
14
15
def read_image(
        path: TypePath,
        itk_first: bool = False,
        ) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image, falling back to a second reader if the first fails.

    Args:
        path: Path to the image file. For the SimpleITK reader, a directory
            is treated as a DICOM series.
        itk_first: If ``True``, try SimpleITK before NiBabel; otherwise try
            NiBabel first.

    Returns:
        The ``(tensor, affine)`` pair produced by whichever reader succeeded.

    The fallback is triggered only by ``ImageFileError`` from NiBabel, or by
    ``RuntimeError`` from SimpleITK. Any error raised by the fallback reader
    propagates to the caller.
    """
    if not itk_first:
        try:
            return _read_nibabel(path)
        except nib.loadsave.ImageFileError:  # try with ITK
            return _read_sitk(path)
    try:
        return _read_sitk(path)
    except RuntimeError:  # try with NiBabel
        return _read_nibabel(path)
30
31
32
def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image with NiBabel.

    The file is read fully into memory (``mmap=False``). Its voxel data is
    returned as a ``float32`` tensor, together with the image affine.
    """
    nii = nib.load(str(path), mmap=False)
    voxels = nii.get_fdata(dtype=np.float32)
    return torch.from_numpy(voxels), nii.affine
38
39
40
def _read_sitk(
        path: TypePath,
        transpose_2d: bool = True,
        ) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image, or a DICOM series directory, with SimpleITK.

    Args:
        path: Path to an image file. If it is a directory, it is assumed to
            contain a DICOM series.
        transpose_2d: If ``True``, swap the two axes of 2D images.

    Returns:
        ``(tensor, affine)``. The tensor holds ``float32`` data, and both
        values come from ``sitk_to_nib(..., keepdim=True)``.
    """
    is_directory = Path(path).is_dir()  # a directory means a DICOM series
    image = _read_dicom(path) if is_directory else sitk.ReadImage(str(path))
    if image.GetDimension() == 2 and transpose_2d:
        image = sitk.PermuteAxes(image, (1, 0))
    array, affine = sitk_to_nib(image, keepdim=True)
    if array.dtype != np.float32:
        array = array.astype(np.float32)
    return torch.from_numpy(array), affine
55
56
57
def _read_dicom(directory: TypePath):
    """Read the DICOM series in ``directory`` with SimpleITK.

    The series files are located with
    ``ImageSeriesReader.GetGDCMSeriesFileNames``.

    Raises:
        FileNotFoundError: If ``directory`` is not a directory, or if no
            DICOM series files are found in it.
    """
    folder = Path(directory)
    if not folder.is_dir():  # unreachable if called from _read_sitk
        raise FileNotFoundError(f'Directory "{folder}" not found')
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(str(folder))
    if not file_names:
        raise FileNotFoundError(
            f'The directory "{folder}"'
            ' does not seem to contain DICOM files'
        )
    series_reader.SetFileNames(file_names)
    return series_reader.Execute()
72
73
74
def write_image(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        itk_first: bool = False,
        squeeze: bool = True,
        channels_last: bool = True,
        ) -> None:
    """Write an image, falling back to a second writer if the first fails.

    Args:
        tensor: Image data. Both writers assert that it is 4D.
        affine: Affine matrix to store with the data.
        path: Output path. Its extension selects the file format.
        itk_first: If ``True``, try SimpleITK before NiBabel; otherwise try
            NiBabel first.
        squeeze: Forwarded to both writers.
        channels_last: Forwarded to the NiBabel writer only. The SimpleITK
            writer does not take this argument.

    The fallback is triggered only by ``RuntimeError`` from SimpleITK, or by
    ``ImageFileError`` from NiBabel (raised for unsupported extensions).
    """
    args = tensor, affine, path
    nibabel_kwargs = dict(squeeze=squeeze, channels_last=channels_last)
    if itk_first:
        try:
            _write_sitk(*args, squeeze=squeeze)
        except RuntimeError:  # try with NiBabel
            _write_nibabel(*args, **nibabel_kwargs)
    else:
        try:
            _write_nibabel(*args, **nibabel_kwargs)
        except nib.loadsave.ImageFileError:  # try with ITK
            _write_sitk(*args, squeeze=squeeze)
94
95
96
def _write_nibabel(
        tensor: TypeData,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = True,
        channels_last: bool = True,
        ) -> None:
    """Write a 4D array to a NIfTI-1 file using NiBabel.

    Expects a path with an extension that can be used by ``nibabel.save``
    to write a NIfTI-1 image: ``'.nii'``, ``'.hdr'`` or ``'.img'``,
    optionally followed by ``'.gz'`` (e.g. ``'.nii.gz'``).

    Args:
        tensor: 4D data. It is converted with ``np.asarray``, so CPU PyTorch
            tensors and NumPy arrays are both accepted.
        affine: Affine matrix stored in the header.
        path: Output path.
        squeeze: If ``True``, remove all singleton dimensions before writing.
        channels_last: If ``True``, move the first dimension (presumably the
            channels) to the end.

    Raises:
        nibabel.loadsave.ImageFileError: If the extension (ignoring a trailing
            ``'.gz'``) is not a NIfTI-1 extension. ``write_image`` relies on
            this error to fall back to SimpleITK.
    """
    assert tensor.ndim == 4
    # np.asarray + ndarray ops work for both torch tensors and NumPy arrays,
    # whereas the torch-only ``.permute`` would fail on NumPy input.
    array = np.asarray(tensor)
    if channels_last:
        array = array.transpose(1, 2, 3, 0)
    if squeeze:
        array = array.squeeze()

    # Only strip a *trailing* '.gz'; other '.gz' substrings in the path must
    # not change the detected extension.
    name = Path(path).name
    if name.endswith('.gz'):
        name = name[:-len('.gz')]
    suffix = Path(name).suffix

    if suffix == '.nii':
        img = nib.Nifti1Image(array, affine)
    elif suffix in ('.hdr', '.img'):
        img = nib.Nifti1Pair(array, affine)
    else:
        raise nib.loadsave.ImageFileError
    # qform_code 1 is NIFTI_XFORM_SCANNER_ANAT; sform_code 0 marks the sform
    # as unused. NOTE(review): the reason for preferring qform over sform is
    # not visible here.
    img.header['qform_code'] = 1
    img.header['sform_code'] = 0
    img.to_filename(str(path))
121
122
123
def _write_sitk(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = True,
        use_compression: bool = True,
        transpose_2d: bool = True,
        ) -> None:
    """Write a 4D tensor to disk using SimpleITK.

    Args:
        tensor: 4D image data.
        affine: Affine matrix, passed to ``nib_to_sitk`` with the data.
        path: Output path. For PNG/JPEG extensions (any letter case), the data
            is cast to ``uint8`` first and a warning is emitted.
        squeeze: Forwarded to ``nib_to_sitk``.
        use_compression: Forwarded to ``sitk.WriteImage``.
        transpose_2d: If ``True``, swap the two axes of 2D images before
            writing.
    """
    assert tensor.ndim == 4
    path = Path(path)
    # Compare case-insensitively so '.PNG' / '.JPG' also get the uint8 cast.
    if path.suffix.lower() in ('.png', '.jpg', '.jpeg'):
        warnings.warn(f'Casting to uint 8 before saving to {path}')
        # NOTE: no clipping is applied, so values outside [0, 255] are not
        # saturated by this cast.
        tensor = tensor.numpy().astype(np.uint8)
    image = nib_to_sitk(tensor, affine, squeeze=squeeze)
    if image.GetDimension() == 2 and transpose_2d:
        image = sitk.PermuteAxes(image, (1, 0))
    sitk.WriteImage(image, str(path), use_compression)
140
141
142
def read_matrix(path: TypePath):
    """Read an affine transform and convert it to a tensor.

    The format is chosen from the file extension:

    - ``.tfm``, ``.h5``: ITK transform files.
    - ``.txt``, ``.trsf``: NiftyReg / blockmatching text matrices.

    Raises:
        ValueError: If the extension is not one of the above.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        return _read_itk_matrix(path)
    if suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        return _read_niftyreg_matrix(path)
    # Previously this fell through to an UnboundLocalError on 'tensor'.
    raise ValueError(
        f'Unknown transform file extension "{suffix}" for "{path}".'
        ' Supported extensions: .tfm, .h5, .txt, .trsf'
    )
0 ignored issues
show
introduced by
The variable tensor does not seem to be defined for all execution paths.
Loading history...
151
152
153
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform.

    The format is chosen from the file extension:

    - ``.tfm``, ``.h5``: ITK transform files.
    - ``.txt``, ``.trsf``: NiftyReg / blockmatching text matrices.

    Raises:
        ValueError: If the extension is not one of the above. Previously
            nothing was written and no error was reported.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        raise ValueError(
            f'Unknown transform file extension "{suffix}" for "{path}".'
            ' Supported extensions: .tfm, .h5, .txt, .trsf'
        )
161
162
163
def _to_itk_convention(matrix):
    """Convert a homogeneous affine from RAS to LPS (ITK) convention.

    The matrix is conjugated by ``FLIPXY`` (x and y negated on both sides)
    and then inverted. The inversion presumably switches the transform
    direction expected by ITK files; confirm against the callers.
    """
    conjugated = np.dot(np.dot(FLIPXY, matrix), FLIPXY)
    return np.linalg.inv(conjugated)
169
170
171
def _from_itk_convention(matrix):
    """Convert a homogeneous affine from LPS (ITK) to RAS convention.

    The matrix is conjugated by ``FLIPXY`` (x and y negated on both sides)
    and then inverted, mirroring ``_to_itk_convention``.
    """
    conjugated = np.dot(FLIPXY, np.dot(matrix, FLIPXY))
    return np.linalg.inv(conjugated)
177
178
179
def _read_itk_matrix(path):
    """Read a 3D affine transform in ITK's format and return it in RAS.

    An ITK affine transform maps a point ``x`` to ``A (x - c) + t + c``.
    ``A`` (9 values) and ``t`` (3 values) are the transform parameters, and
    the center ``c`` is stored in the fixed parameters. The center is folded
    into the translation, so the returned homogeneous matrix is also correct
    when ``c`` is not the origin. For a zero center the result is unchanged.

    Returns:
        A 4x4 tensor holding the homogeneous matrix in RAS convention.
    """
    transform = sitk.ReadTransform(str(path))
    parameters = transform.GetParameters()
    rotation_matrix = np.array(parameters[:9]).reshape(3, 3)
    translation_vector = np.array(parameters[9:]).reshape(3, 1)
    fixed_parameters = transform.GetFixedParameters()
    if len(fixed_parameters) == 3:  # center of a 3D affine transform
        center = np.array(fixed_parameters).reshape(3, 1)
        # ITK offset = translation + center - A @ center
        translation_vector = (
            translation_vector + center - rotation_matrix @ center
        )
    matrix = np.hstack([rotation_matrix, translation_vector])
    homogeneous_matrix_lps = np.vstack([matrix, [0, 0, 0, 1]])
    homogeneous_matrix_ras = _from_itk_convention(homogeneous_matrix_lps)
    return torch.from_numpy(homogeneous_matrix_ras)
191
192
193
def _write_itk_matrix(matrix, tfm_path):
    """Write ``matrix`` (RAS) as an ITK transform file at ``tfm_path``.

    The tfm file contains the matrix from floating to reference.
    """
    _matrix_to_itk_transform(matrix).WriteTransform(str(tfm_path))
197
198
199
def _matrix_to_itk_transform(matrix, dimensions=3):
    """Build a SimpleITK ``AffineTransform`` from a RAS homogeneous matrix.

    Args:
        matrix: Homogeneous affine matrix in RAS convention.
        dimensions: Number of spatial dimensions. The linear part is the
            top-left ``dimensions x dimensions`` block, and the translation
            is column ``dimensions`` (the last column of a homogeneous
            matrix). Previously the column index was hard-coded to 3.
            NOTE: ``_to_itk_convention`` multiplies by the 4x4 ``FLIPXY``, so
            only ``dimensions=3`` currently works end to end.

    Returns:
        The corresponding ``sitk.AffineTransform`` (center at the origin).
    """
    matrix = _to_itk_convention(matrix)
    rotation = matrix[:dimensions, :dimensions].ravel().tolist()
    translation = matrix[:dimensions, dimensions].tolist()
    return sitk.AffineTransform(rotation, translation)
205
206
207
def _read_niftyreg_matrix(trsf_path):
    """Read a NiftyReg / blockmatching text matrix and return its inverse.

    The file is parsed with ``np.loadtxt``, and the loaded matrix is inverted
    before being returned. Files of this format are written as ref -> flo,
    so the inverse is presumably the flo -> ref direction.

    Returns:
        A ``torch.Tensor`` (not a NumPy array) holding the inverted matrix.
    """
    matrix = np.loadtxt(trsf_path)
    matrix = np.linalg.inv(matrix)
    return torch.from_numpy(matrix)
212
213
214
def _write_niftyreg_matrix(matrix, txt_path):
    """Write the inverse of ``matrix`` in NiftyReg's .txt format (ref -> flo).

    Values are written with eight decimal places (``fmt='%.8f'``).
    """
    np.savetxt(txt_path, np.linalg.inv(matrix), fmt='%.8f')
218