Passed
Push — master ( 0e3b0b...4497b8 )
by Fernando
01:17
created

torchio.data.io.nib_to_sitk()   B

Complexity

Conditions 8

Size

Total Lines 41
Code Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
eloc 32
nop 5
dl 0
loc 41
rs 7.2453
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Tuple
4
5
import torch
6
import numpy as np
7
import nibabel as nib
8
import SimpleITK as sitk
9
10
from ..constants import REPO_URL
11
from ..typing import TypePath, TypeData
12
13
14
# Homogeneous 4x4 matrix that negates the first two axes; the ITK convention
# helpers in this file use it to convert between RAS and LPS
FLIPXY = np.diag([-1, -1, 1, 1])
15
16
17
def read_image(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image, trying SimpleITK first and NiBabel as a fallback.

    Args:
        path: Location of the image to read.

    Returns:
        The ``(tensor, affine)`` tuple produced by whichever reader succeeded.

    Raises:
        RuntimeError: If SimpleITK raised a ``RuntimeError`` and NiBabel
            then raised ``ImageFileError`` for the same path.
    """
    try:
        return _read_sitk(path)
    except RuntimeError:  # try with NiBabel
        try:
            return _read_nibabel(path)
        except nib.loadsave.ImageFileError:
            raise RuntimeError(f'File "{path}" not understood')
26
27
28
def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image with NiBabel.

    Args:
        path: Location of the image to read.

    Returns:
        A tuple with the voxel data as a tensor and the image affine.
    """
    nii = nib.load(str(path), mmap=False)
    array = nii.get_fdata(dtype=np.float32)
    if array.ndim == 5:
        # Take index 0 of the fourth axis and move the last axis to the front
        array = array[..., 0, :].transpose(3, 0, 1, 2)
    array = check_uint_to_int(array)
    return torch.from_numpy(array), nii.affine
38
39
40
def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image with SimpleITK.

    A directory is assumed to hold a DICOM series; anything else is read as
    a single image file.

    Args:
        path: File or directory to read.

    Returns:
        A tuple with the voxel data as a tensor and the affine computed by
        ``sitk_to_nib``.
    """
    is_directory = Path(path).is_dir()  # assume DICOM
    image = _read_dicom(path) if is_directory else sitk.ReadImage(str(path))
    array, affine = sitk_to_nib(image, keepdim=True)
    array = check_uint_to_int(array)
    return torch.from_numpy(array), affine
49
50
51
def _read_dicom(directory: TypePath):
52
    directory = Path(directory)
53
    if not directory.is_dir():  # unreachable if called from _read_sitk
54
        raise FileNotFoundError(f'Directory "{directory}" not found')
55
    reader = sitk.ImageSeriesReader()
56
    dicom_names = reader.GetGDCMSeriesFileNames(str(directory))
57
    if not dicom_names:
58
        message = (
59
            f'The directory "{directory}"'
60
            ' does not seem to contain DICOM files'
61
        )
62
        raise FileNotFoundError(message)
63
    reader.SetFileNames(dicom_names)
64
    image = reader.Execute()
65
    return image
66
67
68
def write_image(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = True,
        ) -> None:
    """Write an image to disk, trying SimpleITK first and then NiBabel.

    Args:
        tensor: Image data; both writers assert that it is 4D.
        affine: Affine matrix forwarded to the selected writer.
        path: Output file path.
        squeeze: Forwarded unchanged to the selected writer.
    """
    try:
        _write_sitk(tensor, affine, path, squeeze=squeeze)
    except RuntimeError:  # try with NiBabel
        _write_nibabel(tensor, affine, path, squeeze=squeeze)
79
80
81
def _write_nibabel(
        tensor: TypeData,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = False,
        ) -> None:
    """Write a 4D image to a NIfTI-1 file using NiBabel.

    Expects a path with an extension that can be used by nibabel.save
    to write a NIfTI-1 image, such as '.nii.gz' or '.img'

    Args:
        tensor: 4D image data, either a tensor or a NumPy array. The first
            axis is treated as the components axis.
        affine: Affine matrix stored in the image.
        path: Output file path.
        squeeze: If ``True``, singleton dimensions are removed before saving.

    Raises:
        nib.loadsave.ImageFileError: If the extension of ``path`` is not one
            of the NIfTI-1 extensions handled here.
    """
    assert tensor.ndim == 4
    # Work on a NumPy array from the start so that both tensors and arrays
    # are supported (arrays have no ``permute`` method)
    array = np.asarray(tensor)
    num_components = array.shape[0]

    # NIfTI components must be at the end, in a 5D array
    if num_components == 1:
        array = array[0]
    else:
        array = array[np.newaxis].transpose(2, 3, 4, 0, 1)
    if squeeze:
        array = array.squeeze()
    suffix = Path(str(path).replace('.gz', '')).suffix
    if '.nii' in suffix:
        img = nib.Nifti1Image(array, affine)
    elif '.hdr' in suffix or '.img' in suffix:
        img = nib.Nifti1Pair(array, affine)
    else:
        message = (
            f'Extension of "{path}" is not supported'
            ' for writing NIfTI-1 images with NiBabel'
        )
        raise nib.loadsave.ImageFileError(message)
    if num_components > 1:
        img.header.set_intent('vector')
    img.header['qform_code'] = 1
    img.header['sform_code'] = 0
    nib.save(img, str(path))
112
113
114
def _write_sitk(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = True,
        use_compression: bool = True,
        ) -> None:
    """Write a 4D image to disk using SimpleITK.

    Args:
        tensor: 4D image data.
        affine: Affine matrix used to build the SimpleITK image.
        path: Output file path. For '.png', '.jpg' and '.jpeg' the data is
            cast to ``uint8`` and a ``RuntimeWarning`` is issued.
        squeeze: Forwarded to ``nib_to_sitk``.
        use_compression: Forwarded to ``sitk.WriteImage``.
    """
    assert tensor.ndim == 4
    path = Path(path)
    needs_uint8 = path.suffix in ('.png', '.jpg', '.jpeg')
    if needs_uint8:
        warnings.warn(
            f'Casting to uint 8 before saving to {path}',
            RuntimeWarning,
        )
        tensor = tensor.numpy().astype(np.uint8)
    sitk_image = nib_to_sitk(tensor, affine, squeeze=squeeze)
    sitk.WriteImage(sitk_image, str(path), use_compression)
131
132
133
def read_matrix(path: TypePath):
    """Read an affine transform and convert to tensor.

    The reader is chosen from the file suffix: '.tfm' and '.h5' use the ITK
    reader, '.txt' and '.trsf' use the NiftyReg reader.

    Args:
        path: Location of the transform file.

    Returns:
        The value returned by the selected reader.

    Raises:
        ValueError: If the suffix is not one of the supported ones.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        return _read_itk_matrix(path)
    if suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        return _read_niftyreg_matrix(path)
    raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
144
145
146
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform.

    The writer is chosen from the file suffix: '.tfm' and '.h5' use the ITK
    writer, '.txt' and '.trsf' use the NiftyReg writer.

    Args:
        matrix: Affine matrix to save.
        path: Output file path.

    Raises:
        ValueError: If the suffix is not one of the supported ones.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        # Fail loudly, as read_matrix does, instead of silently writing
        # nothing when the suffix is not recognized
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
154
155
156
def _to_itk_convention(matrix):
    """RAS to LPS

    The matrix is multiplied by FLIPXY on both sides and the result is
    inverted.
    """
    flipped = np.dot(np.dot(FLIPXY, matrix), FLIPXY)
    return np.linalg.inv(flipped)
162
163
164
def _from_itk_convention(matrix):
    """LPS to RAS

    The matrix is multiplied by FLIPXY on both sides and the result is
    inverted.
    """
    flipped = np.dot(FLIPXY, np.dot(matrix, FLIPXY))
    return np.linalg.inv(flipped)
170
171
172
def _read_itk_matrix(path):
    """Read an affine transform in ITK's .tfm format

    The first nine transform parameters fill the 3x3 block and the remaining
    three fill the translation column of a homogeneous matrix, which is then
    converted with ``_from_itk_convention`` and returned as a tensor.
    """
    parameters = sitk.ReadTransform(str(path)).GetParameters()
    homogeneous_matrix_lps = np.eye(4)
    homogeneous_matrix_lps[:3, :3] = np.array(parameters[:9]).reshape(3, 3)
    homogeneous_matrix_lps[:3, 3] = np.array(parameters[9:]).reshape(3)
    return torch.from_numpy(_from_itk_convention(homogeneous_matrix_lps))
184
185
186
def _write_itk_matrix(matrix, tfm_path):
    """The tfm file contains the matrix from floating to reference.

    The matrix is converted to a SimpleITK affine transform, which is then
    written to ``tfm_path``.
    """
    _matrix_to_itk_transform(matrix).WriteTransform(str(tfm_path))
190
191
192
def _matrix_to_itk_transform(matrix, dimensions=3):
    """Build a SimpleITK affine transform from a homogeneous matrix.

    The matrix is first converted with ``_to_itk_convention``. The top-left
    ``dimensions`` x ``dimensions`` block and the first ``dimensions`` rows
    of column 3 are then passed to ``sitk.AffineTransform``.
    """
    itk_matrix = _to_itk_convention(matrix)
    linear_block = itk_matrix[:dimensions, :dimensions]
    offset_column = itk_matrix[:dimensions, 3]
    return sitk.AffineTransform(
        linear_block.ravel().tolist(),
        offset_column.tolist(),
    )
198
199
200
def _read_niftyreg_matrix(trsf_path):
201
    """Read a NiftyReg matrix and return it as a NumPy array"""
202
    matrix = np.loadtxt(trsf_path)
203
    matrix = np.linalg.inv(matrix)
204
    return torch.from_numpy(matrix)
205
206
207
def _write_niftyreg_matrix(matrix, txt_path):
208
    """Write an affine transform in NiftyReg's .txt format (ref -> flo)"""
209
    matrix = np.linalg.inv(matrix)
210
    np.savetxt(txt_path, matrix, fmt='%.8f')
211
212
213
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the top-left 3x3 block of an affine into rotation and spacing.

    Args:
        affine: Matrix whose top-left 3x3 block is decomposed.

    Returns:
        A tuple ``(rotation, spacing)``: ``spacing`` holds the Euclidean
        norm of each column of the block, and ``rotation`` is the block with
        each column divided by its norm.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear_block = affine[:3, :3]
    squared = linear_block * linear_block
    column_norms = np.sqrt(np.sum(squared, axis=0))
    return linear_block / column_norms, column_norms
221
222
223
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        squeeze: bool = False,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a tensor and a 4x4 affine matrix.

    Args:
        data: 4D image data; the first axis holds the channels.
        affine: 4x4 affine matrix. Origin and direction are flipped along
            the first two axes to switch between RAS and LPS.
        squeeze: Not read anywhere in this function.
            NOTE(review): presumably kept for backward compatibility.
        force_3d: If ``True``, the image is not treated as 2D even when the
            last axis has size 1.
        force_4d: If ``True``, the channels axis is kept as an array axis
            instead of being dropped or turned into vector components.
            NOTE(review): the final component-count assertion looks like it
            cannot hold on this path with several channels - confirm.

    Returns:
        The SimpleITK image with origin, spacing and direction set.

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        raise ValueError(f'Input must be 4D, but has shape {tuple(data.shape)}')
    # Possibilities
    # (1, w, h, 1)
    # (c, w, h, 1)
    # (1, w, h, d)
    # (c, w, h, d)
    voxels = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    is_multichannel = not force_4d and voxels.shape[0] > 1
    is_2d = not force_3d and voxels.shape[3] == 1
    if is_2d:
        voxels = voxels[..., 0]
    keep_channels_axis = is_multichannel or force_4d
    if not keep_channels_axis:
        voxels = voxels[0]
    voxels = voxels.transpose()  # (W, H, D, C) or (W, H, D)
    image = sitk.GetImageFromArray(voxels, isVector=is_multichannel)

    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    ras_lps_flip = np.diag((-1, -1, 1))  # used to switch between LPS and RAS
    origin = np.dot(ras_lps_flip, affine[:3, 3])
    direction = np.dot(ras_lps_flip, rotation)
    if is_2d:  # ignore first dimension if 2D (1, W, H, 1)
        direction = direction[:2, :2]
    image.SetOrigin(origin)  # should I add a 4th value if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(direction.flatten())

    # data is known to be 4D here because of the check at the top
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetSize() == data.shape[1: 1 + num_spatial_dims]
    return image
264
265
266
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a data array and a 4x4 affine matrix.

    Args:
        image: SimpleITK image to convert.
        keepdim: If ``False``, the array is additionally passed through
            ``ensure_4d``.

    Returns:
        A tuple ``(data, affine)``. A channels axis is prepended for
        single-component images and a trailing axis is appended for 2D
        images. The affine is built from the image direction, spacing and
        origin, flipped along the first two axes to switch between LPS and
        RAS.

    Raises:
        RuntimeError: If the image direction has neither 9 nor 4 values.
    """
    data = sitk.GetArrayFromImage(image).transpose()
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        data = data[np.newaxis]  # add channels dimension
    input_spatial_dims = image.GetDimension()
    if input_spatial_dims == 2:
        data = data[..., np.newaxis]
    if not keepdim:
        data = ensure_4d(data, num_spatial_dims=input_spatial_dims)
    assert data.shape[0] == num_components
    assert data.shape[1: 1 + input_spatial_dims] == image.GetSize()

    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    num_direction_values = len(direction)
    if num_direction_values == 9:
        rotation = direction.reshape(3, 3)
    elif num_direction_values == 4:
        # ignore first dimension if 2D (1, W, H, 1)
        rotation = np.eye(3)
        rotation[:2, :2] = direction.reshape(2, 2)
        # Pad spacing and origin with values for a third axis
        spacing = (*spacing, 1)
        origin = (*origin, 0)
    else:
        raise RuntimeError(f'Direction not understood: {direction}')

    lps_ras_flip = np.diag((-1, -1, 1))  # used to switch between LPS and RAS
    affine = np.eye(4)
    affine[:3, :3] = np.dot(lps_ras_flip, rotation) * spacing
    affine[:3, 3] = np.dot(lps_ras_flip, origin)
    return data, affine
302
303
304
def _permute_axes(tensor, axes):
    """Reorder axes of a torch tensor or of a NumPy array."""
    if isinstance(tensor, torch.Tensor):
        return tensor.permute(*axes)
    return np.transpose(tensor, axes)


def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> TypeData:
    """Return the input reshaped to 4D, with channels on the first axis.

    Both torch tensors and NumPy arrays are accepted; the output has the
    same type as the input.

    Args:
        tensor: Input with 2, 3, 4 or 5 dimensions.
        num_spatial_dims: Number of spatial dimensions (2 or 3), used to
            interpret 3D inputs. With any other value the layout is guessed:
            if the first or last axis has size 3 the input is treated as a
            2D image with 3 channels, otherwise as a 3D image with one
            channel.

    Returns:
        The 4D tensor or array.

    Raises:
        ValueError: If the input is 5D and its second-to-last axis is not a
            singleton, or if the number of dimensions is not supported.
    """
    # I wish named tensors were properly supported in PyTorch
    num_dimensions = tensor.ndim
    if num_dimensions == 4:
        pass
    elif num_dimensions == 5:  # hope (W, H, D, 1, C)
        if tensor.shape[-2] == 1:
            tensor = tensor[..., 0, :]
            tensor = _permute_axes(tensor, (3, 0, 1, 2))
        else:
            raise ValueError('5D is not supported for shape[-2] > 1')
    elif num_dimensions == 2:  # assume 2D monochannel (W, H)
        tensor = tensor[np.newaxis, ..., np.newaxis]  # (1, W, H, 1)
    elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
        if num_spatial_dims == 2:
            tensor = tensor[..., np.newaxis]  # (C, W, H, 1)
        elif num_spatial_dims == 3:  # (W, H, D)
            tensor = tensor[np.newaxis]  # (1, W, H, D)
        else:  # try to guess
            shape = tensor.shape
            maybe_rgb = 3 in (shape[0], shape[-1])
            if maybe_rgb:
                if shape[-1] == 3:  # (W, H, 3)
                    tensor = _permute_axes(tensor, (2, 0, 1))  # (3, W, H)
                tensor = tensor[..., np.newaxis]  # (3, W, H, 1)
            else:  # (W, H, D)
                tensor = tensor[np.newaxis]  # (1, W, H, D)
    else:
        message = (
            f'{num_dimensions}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert tensor.ndim == 4
    return tensor
339
340
341
def check_uint_to_int(array):
    """Cast uint16 and uint32 arrays to a signed integer type.

    ``uint16`` becomes ``int32`` and ``uint32`` becomes ``int64``; arrays of
    any other dtype are returned unchanged.
    """
    # This is because PyTorch won't take uint16 nor uint32
    signed_replacements = (
        (np.uint16, np.int32),
        (np.uint32, np.int64),
    )
    for unsigned_type, signed_type in signed_replacements:
        if array.dtype == unsigned_type:
            return array.astype(signed_type)
    return array
348