Passed
Push — master ( 0bf8ef...e85db2 )
by Fernando
01:12
created

torchio.data.io.get_ras_affine_from_sitk()   A

Complexity

Conditions 3

Size

Total Lines 21
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 21
rs 9.4
c 0
b 0
f 0
cc 3
nop 1
1
import warnings
2
from pathlib import Path
3
from typing import Tuple, Union
4
5
import torch
6
import numpy as np
7
import nibabel as nib
8
import SimpleITK as sitk
9
10
from ..constants import REPO_URL
11
from ..typing import TypePath, TypeData, TypeTripletFloat, TypeDirection
12
13
14
# Matrices used to switch between LPS and RAS: negating the first two axes
# maps coordinates from one convention onto the other
FLIPXY_33 = np.diag([-1, -1, 1])  # for 3x3 direction matrices and 3-vectors
FLIPXY_44 = np.diag([-1, -1, 1, 1])  # for 4x4 homogeneous matrices
17
18
19
def read_image(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image file (or DICOM directory) as a tensor and an affine.

    SimpleITK is tried first; if it raises ``RuntimeError``, NiBabel is used
    as a fallback.

    Args:
        path: Path to an image file, or to a directory containing a DICOM
            series.

    Returns:
        A ``(tensor, affine)`` tuple.

    Raises:
        RuntimeError: If neither SimpleITK nor NiBabel can read the file.
    """
    try:
        result = _read_sitk(path)
    except RuntimeError:  # try with NiBabel
        try:
            result = _read_nibabel(path)
        except nib.loadsave.ImageFileError as error:
            message = (
                f'File "{path}" not understood.'
                ' Check supported formats at'
                ' https://simpleitk.readthedocs.io/en/master/IO.html#images'
                ' and https://nipy.org/nibabel/api.html#file-formats'
            )
            # Chain explicitly so the NiBabel failure stays visible
            raise RuntimeError(message) from error
    return result
34
35
36
def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image with NiBabel.

    Voxel data are read eagerly (``mmap=False``) as float32. A 5D array is
    reduced to 4D by taking index 0 of its fourth axis and then moving the
    last axis to the front.

    Returns:
        A ``(tensor, affine)`` tuple, where ``affine`` is NiBabel's affine.
    """
    nifti = nib.load(str(path), mmap=False)
    voxels = nifti.get_fdata(dtype=np.float32)
    if voxels.ndim == 5:
        # presumably (W, H, D, 1, C) -> (C, W, H, D); TODO confirm layout
        voxels = voxels[..., 0, :].transpose(3, 0, 1, 2)
    voxels = check_uint_to_int(voxels)
    return torch.as_tensor(voxels), nifti.affine
46
47
48
def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image with SimpleITK.

    A directory is assumed to hold a DICOM series; any other path is passed
    to ``sitk.ReadImage``.

    Returns:
        A ``(tensor, affine)`` tuple, where ``affine`` is the RAS affine.
    """
    image = (
        _read_dicom(path)  # directories are assumed to be DICOM series
        if Path(path).is_dir()
        else sitk.ReadImage(str(path))
    )
    array, affine = sitk_to_nib(image, keepdim=True)
    return torch.as_tensor(check_uint_to_int(array)), affine
57
58
59
def _read_dicom(directory: TypePath):
    """Read the DICOM series stored in ``directory`` with SimpleITK.

    Raises:
        FileNotFoundError: If ``directory`` is not a directory, or if GDCM
            finds no series file names in it.
    """
    folder = Path(directory)
    # Cannot trigger when called from _read_sitk, which checks is_dir() first
    if not folder.is_dir():
        raise FileNotFoundError(f'Directory "{folder}" not found')
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(str(folder))
    if not file_names:
        raise FileNotFoundError(
            f'The directory "{folder}"'
            ' does not seem to contain DICOM files'
        )
    series_reader.SetFileNames(file_names)
    return series_reader.Execute()
74
75
76
def read_shape(path: TypePath) -> Tuple[int, int, int, int]:
    """Return ``(channels, *spatial_size)`` from the image information.

    Only ``ReadImageInformation`` is called on the reader. For 2D images, a
    singleton third spatial dimension is appended.
    """
    info_reader = sitk.ImageFileReader()
    info_reader.SetFileName(str(path))
    info_reader.ReadImageInformation()
    channels = info_reader.GetNumberOfComponents()
    size = info_reader.GetSize()
    if info_reader.GetDimension() == 2:
        size += (1,)
    return (channels, *size)
85
86
87
def read_affine(path: TypePath) -> np.ndarray:
    """Return the 4x4 RAS affine of ``path``, computed from its header."""
    return get_ras_affine_from_sitk(get_reader(path))
91
92
93
def get_reader(path: TypePath, read: bool = True) -> sitk.ImageFileReader:
    """Create a SimpleITK file reader for ``path``.

    Args:
        path: Image file to read.
        read: If ``True``, call ``ReadImageInformation`` on the reader before
            returning it.
    """
    file_reader = sitk.ImageFileReader()
    file_reader.SetFileName(str(path))
    if not read:
        return file_reader
    file_reader.ReadImageInformation()
    return file_reader
99
100
101
def write_image(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = True,
        ) -> None:
    """Write ``tensor`` with ``affine`` to ``path``.

    SimpleITK is tried first; if it raises ``RuntimeError``, NiBabel is used
    instead. ``squeeze`` is only forwarded to the NiBabel writer.
    """
    try:
        _write_sitk(tensor, affine, path)
    except RuntimeError:  # SimpleITK could not write it; try with NiBabel
        _write_nibabel(tensor, affine, path, squeeze=squeeze)
112
113
114
def _write_nibabel(
        tensor: TypeData,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = False,
        ) -> None:
    """Write a 4D channels-first array as a NIfTI-1 image using NiBabel.

    Expects a path with an extension that can be used by nibabel.save
    to write a NIfTI-1 image, such as '.nii.gz' or '.img'.

    Args:
        tensor: 4D CPU tensor or NumPy array whose first dimension holds the
            components (channels), i.e. ``(C, W, H, D)``.
        affine: 4x4 affine stored in the image header.
        path: Output path. A '.nii' suffix (optionally followed by '.gz')
            creates a ``Nifti1Image``; '.hdr' or '.img' create a
            ``Nifti1Pair``.
        squeeze: If ``True``, remove all singleton dimensions before saving.

    Raises:
        nib.loadsave.ImageFileError: If the suffix is not supported.
    """
    # np.asarray accepts both CPU tensors and NumPy arrays, as allowed by
    # the TypeData annotation (Tensor.permute would fail on NumPy input)
    array = np.asarray(tensor)
    assert array.ndim == 4
    num_components = array.shape[0]

    # NIfTI components must be at the end, in a 5D array
    if num_components == 1:
        array = array[0]
    else:
        # (1, C, W, H, D) -> (W, H, D, 1, C)
        array = array[np.newaxis].transpose(2, 3, 4, 0, 1)
    if squeeze:
        array = array.squeeze()
    suffix = Path(str(path).replace('.gz', '')).suffix
    if '.nii' in suffix:
        img = nib.Nifti1Image(array, affine)
    elif '.hdr' in suffix or '.img' in suffix:
        img = nib.Nifti1Pair(array, affine)
    else:
        raise nib.loadsave.ImageFileError(
            f'Unsupported suffix "{suffix}" for NiBabel output "{path}"'
        )
    if num_components > 1:
        img.header.set_intent('vector')
    img.header['qform_code'] = 1
    img.header['sform_code'] = 0
    nib.save(img, str(path))
145
146
147
def _write_sitk(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        use_compression: bool = True,
        ) -> None:
    """Write a 4D tensor and its affine to ``path`` using SimpleITK.

    For PNG/JPEG outputs, the data are cast to ``uint8`` first and a
    ``RuntimeWarning`` is emitted.

    Args:
        tensor: 4D tensor (or NumPy array) passed to ``nib_to_sitk``.
        affine: 4x4 RAS affine.
        path: Output path; the format is chosen by SimpleITK.
        use_compression: Forwarded to ``sitk.WriteImage``.
    """
    assert tensor.ndim == 4
    path = Path(path)
    # Compare case-insensitively so that e.g. '.PNG' is also cast to uint8
    if path.suffix.lower() in ('.png', '.jpg', '.jpeg'):
        warnings.warn(
            f'Casting to uint 8 before saving to {path}',
            RuntimeWarning,
        )
        # np.asarray handles NumPy input too, unlike Tensor.numpy()
        tensor = np.asarray(tensor).astype(np.uint8)
    image = nib_to_sitk(tensor, affine)
    sitk.WriteImage(image, str(path), use_compression)
163
164
165
def read_matrix(path: TypePath):
    """Read an affine transform file and return it as a tensor.

    '.tfm' and '.h5' files are read as ITK transforms; '.txt' and '.trsf'
    files as NiftyReg / blockmatching matrices.

    Raises:
        ValueError: If the suffix is not one of the above.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        return _read_itk_matrix(path)
    if suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        return _read_niftyreg_matrix(path)
    raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
176
177
178
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform to ``path``.

    '.tfm' and '.h5' are written as ITK transforms; '.txt' and '.trsf' as
    NiftyReg / blockmatching matrices.

    Raises:
        ValueError: If the suffix is not one of the above, mirroring
            ``read_matrix``.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        # Fail loudly instead of silently writing nothing
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
186
187
188
def _to_itk_convention(matrix):
    """RAS to LPS.

    Sign-flips the first two axes on both sides of the 4x4 homogeneous
    ``matrix`` and returns the inverse of the result. The inversion
    presumably swaps the reference/floating direction expected by ITK
    (confirm against callers).
    """
    flipped = np.dot(np.dot(FLIPXY_44, matrix), FLIPXY_44)
    return np.linalg.inv(flipped)
194
195
196
def _from_itk_convention(matrix):
    """LPS to RAS.

    Sign-flips the first two axes on both sides of the 4x4 homogeneous
    ``matrix`` and returns the inverse of the result, undoing
    ``_to_itk_convention``.
    """
    flipped = np.dot(FLIPXY_44, np.dot(matrix, FLIPXY_44))
    return np.linalg.inv(flipped)
202
203
204
def _read_itk_matrix(path):
    """Read an affine transform in ITK's .tfm format as a RAS tensor.

    ITK affine transforms map ``x`` to ``A (x - c) + t + c``. ``A`` (row
    major) and ``t`` are the parameters, and the center ``c`` is stored in
    the fixed parameters. The equivalent homogeneous matrix therefore has
    the offset ``t + c - A c``. Ignoring ``c`` would give a wrong offset for
    transforms saved with a non-zero center.

    Returns:
        A 4x4 tensor in RAS convention.
    """
    transform = sitk.ReadTransform(str(path))
    parameters = transform.GetParameters()
    rotation_matrix = np.array(parameters[:9]).reshape(3, 3)
    translation = np.array(parameters[9:], dtype=float).reshape(3)
    fixed_parameters = transform.GetFixedParameters()
    if len(fixed_parameters) == 3:  # center of rotation
        center = np.array(fixed_parameters, dtype=float)
    else:  # no center information; assume the origin
        center = np.zeros(3)
    offset = translation + center - np.dot(rotation_matrix, center)
    matrix = np.hstack([rotation_matrix, offset.reshape(3, 1)])
    homogeneous_matrix_lps = np.vstack([matrix, [0, 0, 0, 1]])
    homogeneous_matrix_ras = _from_itk_convention(homogeneous_matrix_lps)
    return torch.as_tensor(homogeneous_matrix_ras)
216
217
218
def _write_itk_matrix(matrix, tfm_path):
    """Write ``matrix`` as an ITK transform file at ``tfm_path``.

    The tfm file contains the matrix from floating to reference.
    """
    _matrix_to_itk_transform(matrix).WriteTransform(str(tfm_path))
222
223
224
def _matrix_to_itk_transform(matrix, dimensions=3):
    """Build a ``sitk.AffineTransform`` from a 4x4 RAS matrix.

    The matrix is first converted with ``_to_itk_convention``. The linear
    part is its upper-left ``dimensions x dimensions`` block, and the
    translation is taken from column 3.
    """
    itk_matrix = _to_itk_convention(matrix)
    linear_part = itk_matrix[:dimensions, :dimensions].ravel().tolist()
    translation = itk_matrix[:dimensions, 3].tolist()
    return sitk.AffineTransform(linear_part, translation)
230
231
232
def _read_niftyreg_matrix(trsf_path):
    """Read a NiftyReg matrix file and return its inverse as a torch tensor.

    The file is loaded with ``np.loadtxt``. It is inverted because NiftyReg
    files store the reference -> floating transform (see
    ``_write_niftyreg_matrix``).
    """
    matrix = np.loadtxt(trsf_path)
    return torch.as_tensor(np.linalg.inv(matrix))
237
238
239
def _write_niftyreg_matrix(matrix, txt_path):
    """Write an affine transform in NiftyReg's .txt format (ref -> flo).

    The matrix is inverted and saved as text with 8 decimal places.
    """
    np.savetxt(txt_path, np.linalg.inv(matrix), fmt='%.8f')
243
244
245
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the upper-left 3x3 block of ``affine`` into rotation and spacing.

    The spacing is the Euclidean norm of each column of the block. The
    rotation is the block with every column divided by that norm.

    Returns:
        ``(rotation, spacing)``: a 3x3 array and a length-3 array.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear = affine[:3, :3]
    column_norms = np.sqrt((linear * linear).sum(axis=0))
    return linear / column_norms, column_norms
253
254
255
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a tensor and a 4x4 affine matrix.

    Args:
        data: 4D array-like with shape ``(C, W, H, D)``.
        affine: 4x4 RAS affine, converted to float64.
        force_3d: If ``True``, a singleton last dimension does not produce a
            2D image.
        force_4d: If ``True``, the first dimension is neither treated as
            channels nor dropped.

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        shape = tuple(data.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    # Possibilities
    # (1, w, h, 1)
    # (c, w, h, 1)
    # (1, w, h, d)
    # (c, w, h, d)
    array = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    is_multichannel = array.shape[0] > 1 and not force_4d
    is_2d = array.shape[3] == 1 and not force_3d
    if is_2d:
        array = array[..., 0]
    if not is_multichannel and not force_4d:
        array = array[0]
    # NumPy order is now (D, H, W, C) or (D, H, W) (no D if 2D); SimpleITK
    # reads arrays in reverse axis order, so its size is (W, H, D)
    array = array.transpose()
    image = sitk.GetImageFromArray(array, isVector=is_multichannel)

    origin, spacing, direction = get_sitk_metadata_from_ras_affine(
        affine,
        is_2d=is_2d,
    )
    image.SetOrigin(origin)  # should I add a 4th value if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(direction)

    # Sanity checks; data is known to be 4D at this point.
    # NOTE(review): with force_4d=True, isVector is False, so the component
    # check below appears to fail whenever C > 1 and the size check compares
    # a 4-tuple to a 3-tuple — confirm whether force_4d is supported.
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetSize() == data.shape[1: 1 + num_spatial_dims]

    return image
296
297
298
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image to a channels-first array and a RAS affine.

    A channel axis is prepended for single-component images, and a singleton
    last axis is appended for 2D images.

    Args:
        image: SimpleITK image.
        keepdim: If ``False``, the array is also passed through
            ``ensure_4d`` (which returns a torch tensor).

    Returns:
        A ``(data, affine)`` tuple.
    """
    array = check_uint_to_int(sitk.GetArrayFromImage(image).transpose())
    components = image.GetNumberOfComponentsPerPixel()
    spatial_dims = image.GetDimension()
    if components == 1:
        array = array[np.newaxis]  # add channels dimension
    if spatial_dims == 2:
        array = array[..., np.newaxis]  # add a singleton last dimension
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=spatial_dims)
    assert array.shape[0] == components
    assert array.shape[1: 1 + spatial_dims] == image.GetSize()
    return array, get_ras_affine_from_sitk(image)
316
317
318
def get_ras_affine_from_sitk(
        sitk_object: Union[sitk.Image, sitk.ImageFileReader],
        ) -> np.ndarray:
    """Compute a 4x4 RAS affine from the metadata of a SimpleITK object.

    Direction and origin are read in LPS and converted to RAS by negating
    the first two axes. For 2D objects, the 2x2 direction is embedded in a
    3x3 identity, and a unit spacing and zero origin are appended for the
    third axis.

    Args:
        sitk_object: Object exposing ``GetSpacing``, ``GetDirection`` and
            ``GetOrigin``, such as an image or a file reader on which
            ``ReadImageInformation`` has been called.

    Returns:
        A 4x4 float64 affine.

    Raises:
        ValueError: If the direction has neither 4 (2D) nor 9 (3D) elements.
    """
    spacing = np.array(sitk_object.GetSpacing())
    direction_lps = np.array(sitk_object.GetDirection())
    origin_lps = np.array(sitk_object.GetOrigin())
    if len(direction_lps) == 9:
        rotation_lps = direction_lps.reshape(3, 3)
    elif len(direction_lps) == 4:  # 2D: embed the 2x2 direction in 3x3
        rotation_lps = np.eye(3)
        rotation_lps[:2, :2] = direction_lps.reshape(2, 2)
        spacing = np.append(spacing, 1)
        origin_lps = np.append(origin_lps, 0)
    else:
        # Without this branch, rotation_lps would be unbound below
        message = (
            'Expected a 2D or 3D direction matrix (4 or 9 elements),'
            f' but got {len(direction_lps)} elements'
        )
        raise ValueError(message)
    rotation_ras = np.dot(FLIPXY_33, rotation_lps)
    rotation_ras_zoom = rotation_ras * spacing  # scale each column
    translation_ras = np.dot(FLIPXY_33, origin_lps)
    affine = np.eye(4)
    affine[:3, :3] = rotation_ras_zoom
    affine[:3, 3] = translation_ras
    return affine
339
340
341
def get_sitk_metadata_from_ras_affine(
        affine: np.ndarray,
        is_2d: bool = False,
        ) -> Tuple[TypeTripletFloat, TypeTripletFloat, TypeDirection]:
    """Derive SimpleITK origin, spacing and direction from a RAS affine.

    Origin and direction are converted to LPS by negating the first two
    axes. For 2D images, only the upper-left 2x2 block of the direction is
    kept; origin and spacing keep three values.

    Returns:
        ``(origin, spacing, direction)``. ``origin`` and ``spacing`` are
        NumPy arrays, and ``direction`` is a flat row-major tuple.
    """
    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    origin_lps = np.dot(FLIPXY_33, affine[:3, 3])
    direction_lps = np.dot(FLIPXY_33, rotation)
    if is_2d:  # keep only the in-plane 2x2 block
        direction_lps = direction_lps[:2, :2]
    return origin_lps, spacing, tuple(direction_lps.flatten())
352
353
354
def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> TypeData:
    """Return ``tensor`` as a 4D ``(C, W, H, D)`` torch tensor.

    Handled inputs:
        2D ``(W, H)``: single-channel 2D, becomes ``(1, W, H, 1)``.
        3D: ``(C, W, H)`` if ``num_spatial_dims == 2``; ``(W, H, D)`` if
            ``num_spatial_dims == 3``; otherwise guessed, treating a size-3
            first or last axis as RGB channels.
        4D: returned unchanged.
        5D ``(W, H, D, 1, C)``: becomes ``(C, W, H, D)``.

    Raises:
        ValueError: For 5D inputs whose fourth axis is not 1, and for any
            other number of dimensions.
    """
    # I wish named tensors were properly supported in PyTorch
    tensor = torch.as_tensor(tensor)
    ndim = tensor.ndim
    if ndim == 2:  # assume 2D monochannel (W, H)
        tensor = tensor[np.newaxis, ..., np.newaxis]  # (1, W, H, 1)
    elif ndim == 3:  # 2D multichannel or 3D monochannel?
        if num_spatial_dims == 2:
            tensor = tensor[..., np.newaxis]  # (C, W, H, 1)
        elif num_spatial_dims == 3:
            tensor = tensor[np.newaxis]  # (W, H, D) -> (1, W, H, D)
        elif 3 in (tensor.shape[0], tensor.shape[-1]):  # maybe RGB
            if tensor.shape[-1] == 3:
                tensor = tensor.permute(2, 0, 1)  # (W, H, 3) -> (3, W, H)
            tensor = tensor[..., np.newaxis]  # (3, W, H, 1)
        else:
            tensor = tensor[np.newaxis]  # (W, H, D) -> (1, W, H, D)
    elif ndim == 5:  # hope (W, H, D, 1, C)
        if tensor.shape[-2] != 1:
            raise ValueError('5D is not supported for shape[-2] > 1')
        tensor = tensor[..., 0, :].permute(3, 0, 1, 2)
    elif ndim != 4:
        message = (
            f'{ndim}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert tensor.ndim == 4
    return tensor
390
391
392
def check_uint_to_int(array):
    """Widen unsigned integer arrays that PyTorch cannot hold.

    PyTorch won't take uint16 nor uint32, so ``uint16`` is cast to ``int32``
    and ``uint32`` to ``int64``. Any other array is returned as is (the same
    object).
    """
    for unsigned, signed in ((np.uint16, np.int32), (np.uint32, np.int64)):
        if array.dtype == unsigned:
            return array.astype(signed)
    return array
399