torchio.data.io.read_image()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 19
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 19
rs 9.7
c 0
b 0
f 0
cc 3
nop 1
1
import warnings
2
from pathlib import Path
3
from typing import Tuple, Union, Optional
4
5
import torch
6
import numpy as np
7
import nibabel as nib
8
import SimpleITK as sitk
9
10
from ..constants import REPO_URL
11
from ..typing import TypePath, TypeData, TypeTripletFloat, TypeDirection
12
13
14
# Diagonal matrices that negate the first two axes. Multiplying by them
# switches between the LPS convention (used by SimpleITK below) and RAS.
FLIPXY_33 = np.diag((-1, -1, 1))
FLIPXY_44 = np.diag((-1, -1, 1, 1))

# File extensions of image formats that are typically 2D, in both lower and
# upper case
formats = ['.jpg', '.jpeg', '.bmp', '.png', '.tif', '.tiff']
IMAGE_2D_FORMATS = formats + list(map(str.upper, formats))
21
22
23
def read_image(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image file (or a DICOM directory) into a tensor and an affine.

    SimpleITK is tried first. If it raises ``RuntimeError``, a warning is
    emitted and NiBabel is tried instead.

    Args:
        path: Path to an image file, or to a directory (which ``_read_sitk``
            reads as a DICOM series).

    Returns:
        The ``(tensor, affine)`` tuple produced by ``_read_sitk`` or, as a
        fallback, by ``_read_nibabel``.

    Raises:
        RuntimeError: If neither SimpleITK nor NiBabel can read the file.
    """
    try:
        result = _read_sitk(path)
    except RuntimeError as sitk_error:  # try with NiBabel
        message = (
            f'Error loading image with SimpleITK:\n{sitk_error}\n\nTrying NiBabel...'
        )
        warnings.warn(message)
        # Stay inside the except block so that any NiBabel failure keeps the
        # SimpleITK error as its implicit context in the traceback
        try:
            result = _read_nibabel(path)
        except nib.loadsave.ImageFileError as nibabel_error:
            message = (
                f'File "{path}" not understood.'
                ' Check supported formats at'
                ' https://simpleitk.readthedocs.io/en/master/IO.html#images'
                ' and https://nipy.org/nibabel/api.html#file-formats'
            )
            raise RuntimeError(message) from nibabel_error
    return result
42
43
44
def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image with NiBabel.

    The voxel data are loaded eagerly (``mmap=False``) as float32. For 5D
    arrays, only index 0 of the fourth axis is kept and the last axis is
    moved to the front. Arrays of any other dimensionality are returned as
    they are (they are not forced to 4D here).

    Args:
        path: Path to a file that ``nibabel.load`` understands.

    Returns:
        A tuple ``(tensor, affine)`` with the voxel data as a torch tensor
        and NiBabel's ``img.affine``.
    """
    nii = nib.load(str(path), mmap=False)
    array = nii.get_fdata(dtype=np.float32)
    if array.ndim == 5:
        # NIfTI stores vector components at the end of a 5D array:
        # drop the 4th axis and put the components first
        array = array[..., 0, :].transpose(3, 0, 1, 2)
    # The data are float32 at this point, so this conversion is presumably a
    # no-op here; kept for symmetry with the SimpleITK reader
    array = check_uint_to_int(array)
    return torch.as_tensor(array), nii.affine
54
55
56
def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image (or a DICOM directory) with SimpleITK.

    A directory is read as a DICOM series with ``_read_dicom``. Any other
    path is read with ``sitk.ReadImage``. The image is converted with
    ``sitk_to_nib(keepdim=True)``, so its dimensionality is not forced to 4D
    here.

    Returns:
        A tuple ``(tensor, affine)`` where ``affine`` is the RAS affine
        computed by ``sitk_to_nib``.
    """
    is_dicom_directory = Path(path).is_dir()  # assume DICOM
    if is_dicom_directory:
        image = _read_dicom(path)
    else:
        image = sitk.ReadImage(str(path))
    array, affine = sitk_to_nib(image, keepdim=True)
    tensor = torch.as_tensor(check_uint_to_int(array))
    return tensor, affine
65
66
67
def _read_dicom(directory: TypePath):
    """Read a DICOM series stored in a directory with SimpleITK.

    The series file names are found with GDCM
    (``ImageSeriesReader.GetGDCMSeriesFileNames``) and read as one image.

    Args:
        directory: Path to the directory containing the DICOM files.

    Returns:
        The SimpleITK image returned by ``ImageSeriesReader.Execute``.

    Raises:
        FileNotFoundError: If ``directory`` is not an existing directory, or
            if no DICOM series file names are found in it.
    """
    folder = Path(directory)
    if not folder.is_dir():  # unreachable if called from _read_sitk
        raise FileNotFoundError(f'Directory "{folder}" not found')
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(str(folder))
    if not file_names:
        raise FileNotFoundError(
            f'The directory "{folder}"'
            ' does not seem to contain DICOM files'
        )
    series_reader.SetFileNames(file_names)
    return series_reader.Execute()
82
83
84
def read_shape(path: TypePath) -> Tuple[int, int, int, int]:
    """Return ``(num_channels, *spatial_shape)`` from the image information.

    Only ``ReadImageInformation`` is called (through ``get_reader``), so the
    voxel data are not loaded. For 2D images a singleton third spatial
    dimension is appended. For 4D images (assumed to be a badly written
    NIfTI), the last size is taken as the number of channels.
    """
    reader = get_reader(path)
    channels = reader.GetNumberOfComponents()
    size = tuple(reader.GetSize())
    dimensions = reader.GetDimension()
    if dimensions == 2:
        size = size + (1,)
    elif dimensions == 4:  # assume bad NIfTI
        size, channels = size[:-1], size[-1]
    return (channels,) + size
97
98
99
def read_affine(path: TypePath) -> np.ndarray:
    """Return the 4x4 RAS affine of an image, computed from its header only."""
    return get_ras_affine_from_sitk(get_reader(path))
103
104
105
def get_reader(path: TypePath, read: bool = True) -> sitk.ImageFileReader:
    """Create a SimpleITK file reader for ``path``.

    Args:
        path: Path to the image file.
        read: If ``True``, call ``ReadImageInformation`` so that the reader's
            metadata getters (size, spacing, origin, direction...) can be
            used without loading the voxel data.

    Returns:
        The configured ``sitk.ImageFileReader``.
    """
    file_reader = sitk.ImageFileReader()
    file_reader.SetFileName(str(path))
    if not read:
        return file_reader
    file_reader.ReadImageInformation()
    return file_reader
111
112
113
def write_image(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        squeeze: Optional[bool] = None,
        ) -> None:
    """Write a tensor and its affine to ``path``.

    SimpleITK (``_write_sitk``) is tried first. If it raises
    ``RuntimeError``, the image is written with NiBabel (``_write_nibabel``)
    instead. Note that ``squeeze`` is only used by the SimpleITK writer.
    """
    try:
        _write_sitk(tensor, affine, path, squeeze=squeeze)
    except RuntimeError:  # SimpleITK failed: try with NiBabel
        _write_nibabel(tensor, affine, path)
124
125
126
def _write_nibabel(
        tensor: TypeData,
        affine: TypeData,
        path: TypePath,
        ) -> None:
    """Write a 4D tensor or array to a NIfTI-1 file with NiBabel.

    Expects a path with an extension that can be used by nibabel.save
    to write a NIfTI-1 image, such as '.nii.gz' or '.img'.

    Args:
        tensor: 4D torch tensor or NumPy array whose first dimension is the
            number of components (channels).
        affine: 4x4 affine matrix passed to the NiBabel image.
        path: Output path. A trailing '.gz' is ignored when choosing between
            ``Nifti1Image`` ('.nii') and ``Nifti1Pair`` ('.hdr' / '.img').

    Raises:
        ValueError: If ``tensor`` is not 4D.
        nibabel.loadsave.ImageFileError: If the suffix is not supported.
    """
    # Work in NumPy so that both torch tensors and NumPy arrays are accepted
    array = np.asarray(tensor)
    if array.ndim != 4:
        raise ValueError(f'Input must be 4D, but has shape {tuple(array.shape)}')
    num_components = array.shape[0]

    # NIfTI components must be at the end, in a 5D array:
    # (C, W, H, D) -> (W, H, D, 1, C)
    if num_components == 1:
        array = array[0]
    else:
        array = array[np.newaxis].transpose(2, 3, 4, 0, 1)

    # Only strip a trailing '.gz' (e.g. '.nii.gz' -> '.nii')
    path_string = str(path)
    if path_string.endswith('.gz'):
        path_string = path_string[:-len('.gz')]
    suffix = Path(path_string).suffix
    if '.nii' in suffix:
        img = nib.Nifti1Image(array, affine)
    elif '.hdr' in suffix or '.img' in suffix:
        img = nib.Nifti1Pair(array, affine)
    else:
        raise nib.loadsave.ImageFileError(
            f'Cannot write "{path}" with NiBabel: unsupported suffix'
            f' "{suffix}" (expected ".nii", ".hdr" or ".img",'
            ' optionally followed by ".gz")'
        )
    if num_components > 1:
        img.header.set_intent('vector')
    # Mark the qform as set (code 1) and the sform as unset (code 0)
    img.header['qform_code'] = 1
    img.header['sform_code'] = 0
    nib.save(img, str(path))
155
156
157
def _write_sitk(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        use_compression: bool = True,
        squeeze: Optional[bool] = None,
        ) -> None:
    """Write a 4D (C, W, H, D) tensor to ``path`` with SimpleITK.

    Args:
        tensor: 4D tensor (or array) with channels first. A last dimension
            of size 1 may be written as a 2D image (see ``squeeze``).
        affine: 4x4 RAS affine matrix.
        path: Output path. Its suffix (compared case-insensitively) selects
            the uint8 cast and the default 2D/3D behavior.
        use_compression: Passed to ``sitk.WriteImage``.
        squeeze: If ``None``, a singleton last dimension is dropped only for
            typically 2D formats (``IMAGE_2D_FORMATS``). If ``True``, it may
            always be dropped. If ``False``, the image is always written as 3D.

    Raises:
        ValueError: If ``tensor`` is not 4D.
    """
    if tensor.ndim != 4:
        shape = tuple(tensor.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    path = Path(path)
    # Lower-case the suffix so that e.g. '.PNG' is treated like '.png'
    suffix = path.suffix.lower()
    if suffix in ('.png', '.jpg', '.jpeg', '.bmp'):
        warnings.warn(
            f'Casting to uint 8 before saving to {path}',
            RuntimeWarning,
        )
        # NOTE: astype does not clip, so values outside [0, 255] are not
        # saturated
        tensor = np.asarray(tensor).astype(np.uint8)
    if squeeze is None:
        force_3d = suffix not in IMAGE_2D_FORMATS
    else:
        force_3d = not squeeze
    image = nib_to_sitk(tensor, affine, force_3d=force_3d)
    sitk.WriteImage(image, str(path), use_compression)
178
179
180
def read_matrix(path: TypePath):
    """Read an affine transform and convert to tensor.

    The suffix selects the format: '.tfm' / '.h5' (ITK) or '.txt' / '.trsf'
    (NiftyReg, blockmatching).

    Raises:
        ValueError: If the suffix is not one of the above.
    """
    path = Path(path)
    readers = {
        '.tfm': _read_itk_matrix,  # ITK
        '.h5': _read_itk_matrix,  # ITK
        '.txt': _read_niftyreg_matrix,  # NiftyReg, blockmatching
        '.trsf': _read_niftyreg_matrix,  # NiftyReg, blockmatching
    }
    reader = readers.get(path.suffix)
    if reader is None:
        raise ValueError(f'Unknown suffix for transform file: "{path.suffix}"')
    return reader(path)
191
192
193
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform.

    The suffix selects the format: '.tfm' / '.h5' (ITK) or '.txt' / '.trsf'
    (NiftyReg, blockmatching), mirroring ``read_matrix``.

    Args:
        matrix: 4x4 affine matrix.
        path: Output path.

    Raises:
        ValueError: If the suffix is not one of the above. Without this,
            nothing would be written and the caller would not be told.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
201
202
203
def _to_itk_convention(matrix):
    """RAS to LPS.

    Conjugates ``matrix`` with ``FLIPXY_44`` (negation of the first two axes)
    and returns the inverse of the result. The inversion presumably accounts
    for ITK storing the transform in the opposite direction — confirm.
    """
    flipped = np.dot(np.dot(FLIPXY_44, matrix), FLIPXY_44)
    return np.linalg.inv(flipped)
209
210
211
def _from_itk_convention(matrix):
    """LPS to RAS.

    Conjugates ``matrix`` with ``FLIPXY_44`` (negation of the first two axes)
    and returns the inverse of the result, i.e. the same operation as the
    RAS-to-LPS conversion, which undoes it.
    """
    flipped = np.dot(FLIPXY_44, np.dot(matrix, FLIPXY_44))
    return np.linalg.inv(flipped)
217
218
219
def _read_itk_matrix(path):
    """Read an affine transform in ITK's .tfm format.

    The first 9 ITK parameters are read as a row-major 3x3 matrix ``A`` and
    the next 3 as the translation ``t``. ITK affine transforms are applied
    around a center ``c`` (their fixed parameters) as
    ``y = A (x - c) + c + t``. When a 3-element center is present, it is
    folded into the translation so that the homogeneous matrix is correct for
    any center (for a zero center this changes nothing). The matrix is then
    converted from ITK's LPS convention to RAS with ``_from_itk_convention``.

    Returns:
        4x4 ``torch.Tensor``.
    """
    transform = sitk.ReadTransform(str(path))
    parameters = transform.GetParameters()
    rotation_matrix = np.array(parameters[:9]).reshape(3, 3)
    translation = np.array(parameters[9:]).reshape(3)
    center = np.array(transform.GetFixedParameters(), dtype=float)
    if center.size == 3:
        # Effective offset of x -> A (x - c) + c + t
        translation = translation + center - rotation_matrix.dot(center)
    matrix = np.hstack([rotation_matrix, translation.reshape(3, 1)])
    homogeneous_matrix_lps = np.vstack([matrix, [0, 0, 0, 1]])
    homogeneous_matrix_ras = _from_itk_convention(homogeneous_matrix_lps)
    return torch.as_tensor(homogeneous_matrix_ras)
231
232
233
def _write_itk_matrix(matrix, tfm_path):
    """Write ``matrix`` as an ITK transform file.

    The tfm file contains the matrix from floating to reference.
    """
    itk_transform = _matrix_to_itk_transform(matrix)
    itk_transform.WriteTransform(str(tfm_path))
237
238
239
def _matrix_to_itk_transform(matrix, dimensions=3):
    """Build a SimpleITK ``AffineTransform`` from a 4x4 RAS matrix.

    The matrix is first converted with ``_to_itk_convention``. Its top-left
    ``dimensions x dimensions`` block (flattened row by row) becomes the
    transform matrix.

    NOTE(review): the translation is always read from column 3, so values of
    ``dimensions`` other than 3 presumably do not work as intended — confirm.
    """
    itk_matrix = _to_itk_convention(matrix)
    linear_part = itk_matrix[:dimensions, :dimensions]
    offset = itk_matrix[:dimensions, 3]
    return sitk.AffineTransform(linear_part.ravel().tolist(), offset.tolist())
245
246
247
def _read_niftyreg_matrix(trsf_path):
248
    """Read a NiftyReg matrix and return it as a NumPy array"""
249
    matrix = np.loadtxt(trsf_path)
250
    matrix = np.linalg.inv(matrix)
251
    return torch.as_tensor(matrix)
252
253
254
def _write_niftyreg_matrix(matrix, txt_path):
255
    """Write an affine transform in NiftyReg's .txt format (ref -> flo)"""
256
    matrix = np.linalg.inv(matrix)
257
    np.savetxt(txt_path, matrix, fmt='%.8f')
258
259
260
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the 3x3 linear part of an affine into directions and spacing.

    The spacing of each axis is the Euclidean norm of the corresponding
    column of ``affine[:3, :3]``. The returned "rotation" is that block with
    each column divided by its norm, which is a pure rotation only if the
    columns are orthogonal. A zero column would lead to a division by zero.

    Adapted from
    https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    """
    linear_part = affine[:3, :3]
    column_norms = np.sqrt(np.square(linear_part).sum(axis=0))
    return linear_part / column_norms, column_norms
268
269
270
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a tensor and a 4x4 affine matrix.

    Args:
        data: 4D array or tensor with shape (C, W, H, D). A first dimension
            larger than 1 is written as vector components, and a last
            dimension of size 1 produces a 2D image.
        affine: 4x4 RAS affine matrix.
        force_3d: If ``True``, keep a singleton last dimension instead of
            producing a 2D image.
        force_4d: If ``True``, do not treat the first dimension as vector
            components.

    Returns:
        The SimpleITK image with origin, spacing and direction set from
        ``affine`` (converted to LPS by ``get_sitk_metadata_from_ras_affine``).

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        shape = tuple(data.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    # Possible input shapes:
    # (1, W, H, 1): 2D, single channel
    # (C, W, H, 1): 2D, multichannel
    # (1, W, H, D): 3D, single channel
    # (C, W, H, D): 3D, multichannel
    array = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    is_multichannel = array.shape[0] > 1 and not force_4d
    is_2d = array.shape[3] == 1 and not force_3d
    if is_2d:
        array = array[..., 0]
    if not is_multichannel and not force_4d:
        array = array[0]
    # Reverse the axes because SimpleITK reads NumPy arrays in reversed index
    # order: array is now (D, H, W, C), (D, H, W), (H, W, C) or (H, W), so
    # that the image size is (W, H, D) or (W, H)
    array = array.transpose()
    image = sitk.GetImageFromArray(array, isVector=is_multichannel)

    origin, spacing, direction = get_sitk_metadata_from_ras_affine(
        affine,
        is_2d=is_2d,
    )
    # NOTE(review): with force_4d=True the image built above keeps four
    # dimensions, but origin and spacing are 3-tuples and the size check below
    # compares against at most 3 spatial dimensions, so this path looks
    # broken — confirm before relying on it.
    image.SetOrigin(origin)
    image.SetSpacing(spacing)
    image.SetDirection(direction)

    # Sanity checks (data is known to be 4D at this point)
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    assert image.GetSize() == data.shape[1:1 + num_spatial_dims]

    return image
311
312
313
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image to channels-first data and a RAS affine.

    The array from ``sitk.GetArrayFromImage`` is transposed and a channels
    dimension is added for scalar images. For 2D images a singleton last
    dimension is appended. 4D images (probably a badly written NIfTI) are
    reshaped so that their last axis becomes the channels axis.

    Args:
        image: SimpleITK image.
        keepdim: If ``False``, the data are passed through ``ensure_4d``,
            which returns a torch tensor. If ``True``, the NumPy array is
            returned as it is.

    Returns:
        A tuple ``(data, affine)`` where ``affine`` is the 4x4 RAS affine from
        ``get_ras_affine_from_sitk``.
    """
    array = check_uint_to_int(sitk.GetArrayFromImage(image).transpose())
    channels = image.GetNumberOfComponentsPerPixel()
    if channels == 1:
        array = array[np.newaxis]  # add channels dimension
    spatial_dims = image.GetDimension()
    if spatial_dims == 2:
        array = array[..., np.newaxis]
    elif spatial_dims == 4:
        # Probably a bad NIfTI (1, sx, sy, sz, c): try to fix it by moving
        # the last axis to the channels position
        channels = array.shape[-1]
        array = array[0].transpose(3, 0, 1, 2)
        spatial_dims = 3
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=spatial_dims)
    assert array.shape[0] == channels
    return array, get_ras_affine_from_sitk(image)
336
337
338
def get_ras_affine_from_sitk(
        sitk_object: Union[sitk.Image, sitk.ImageFileReader],
        ) -> np.ndarray:
    """Build a 4x4 RAS affine from the LPS metadata of a SimpleITK object.

    Spacing, direction and origin are read from the image or reader. 2D
    metadata are padded to 3D (unit spacing, zero origin, identity for the
    third axis). 4D metadata are truncated to their first three axes.

    Args:
        sitk_object: SimpleITK image or file reader (after
            ``ReadImageInformation``).

    Returns:
        4x4 affine matrix in RAS convention.

    Raises:
        ValueError: If the direction does not have 4, 9 or 16 elements (i.e.
            the object is not 2D, 3D or 4D).
    """
    spacing = np.array(sitk_object.GetSpacing())
    direction_lps = np.array(sitk_object.GetDirection())
    origin_lps = np.array(sitk_object.GetOrigin())
    direction_length = len(direction_lps)
    if direction_length == 9:
        rotation_lps = direction_lps.reshape(3, 3)
    elif direction_length == 4:  # ignore last dimension if 2D (1, W, H, 1)
        rotation_lps_2d = direction_lps.reshape(2, 2)
        rotation_lps = np.eye(3)
        rotation_lps[:2, :2] = rotation_lps_2d
        spacing = np.append(spacing, 1)
        origin_lps = np.append(origin_lps, 0)
    elif direction_length == 16:  # probably a bad NIfTI. Let's try to fix it
        rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
        spacing = spacing[:-1]
        origin_lps = origin_lps[:-1]
    else:
        # Without this branch, rotation_lps would be unbound below
        message = (
            f'Cannot compute an affine from a direction with {direction_length}'
            ' elements; only 2D (4), 3D (9) and 4D (16) images are supported'
        )
        raise ValueError(message)
    rotation_ras = np.dot(FLIPXY_33, rotation_lps)
    rotation_ras_zoom = rotation_ras * spacing  # scale each column
    translation_ras = np.dot(FLIPXY_33, origin_lps)
    affine = np.eye(4)
    affine[:3, :3] = rotation_ras_zoom
    affine[:3, 3] = translation_ras
    return affine
364
365
366
def get_sitk_metadata_from_ras_affine(
        affine: np.ndarray,
        is_2d: bool = False,
        lps: bool = True,
        ) -> Tuple[TypeTripletFloat, TypeTripletFloat, TypeDirection]:
    """Convert a RAS affine into SimpleITK origin, spacing and direction.

    Args:
        affine: 4x4 RAS affine matrix.
        is_2d: If ``True``, the orientation is ignored and the direction is a
            flattened 2x2 matrix (diag(-1, -1) in LPS, identity in RAS).
        lps: If ``True`` (default), origin and direction are flipped to the
            LPS convention. Otherwise they are returned in RAS.

    Returns:
        ``(origin, spacing, direction)`` as tuples: origin and spacing have
        three values, and direction has nine values (four if ``is_2d``).
    """
    direction_ras, spacing_array = get_rotation_and_spacing_from_affine(affine)
    origin_ras = affine[:3, 3]
    origin_lps = np.dot(FLIPXY_33, origin_ras)
    direction_lps = np.dot(FLIPXY_33, direction_ras)
    if is_2d:  # ignore orientation if 2D (1, W, H, 1)
        direction_lps = np.diag((-1, -1)).astype(np.float64)
        direction_ras = np.diag((1, 1)).astype(np.float64)
    if lps:
        origin_array, direction_array = origin_lps, direction_lps
    else:
        origin_array, direction_array = origin_ras, direction_ras
    # Plain tuples, to comply with the typing hints
    origin = tuple(origin_array)
    spacing = tuple(spacing_array)
    direction = tuple(direction_array.flatten())
    return origin, spacing, direction
394
395
396
def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> TypeData:
    """Return the input as a 4D channels-first torch tensor.

    Handled input dimensionalities:

    - 4D: returned unchanged.
    - 5D: assumed (W, H, D, 1, C); becomes (C, W, H, D).
    - 2D: assumed single-channel (W, H); becomes (1, W, H, 1).
    - 3D: ambiguous. ``num_spatial_dims=2`` means (C, W, H) and
      ``num_spatial_dims=3`` means (W, H, D). Otherwise a size of 3 in the
      first or last position is guessed to be RGB channels.

    Raises:
        ValueError: For 5D inputs with ``shape[-2] > 1`` and for any other
            dimensionality.
    """
    # I wish named tensors were properly supported in PyTorch
    tensor = torch.as_tensor(tensor)
    ndim = tensor.ndim
    if ndim == 5:  # hope (W, H, D, 1, C)
        if tensor.shape[-2] != 1:
            raise ValueError('5D is not supported for shape[-2] > 1')
        tensor = tensor[..., 0, :].permute(3, 0, 1, 2)
    elif ndim == 2:  # assume 2D monochannel (W, H)
        tensor = tensor[np.newaxis, ..., np.newaxis]  # (1, W, H, 1)
    elif ndim == 3:  # 2D multichannel or 3D monochannel?
        if num_spatial_dims == 2:
            tensor = tensor[..., np.newaxis]  # (C, W, H, 1)
        elif num_spatial_dims == 3:  # (W, H, D)
            tensor = tensor[np.newaxis]  # (1, W, H, D)
        elif 3 in (tensor.shape[0], tensor.shape[-1]):  # guess: maybe RGB
            if tensor.shape[-1] == 3:  # (W, H, 3)
                tensor = tensor.permute(2, 0, 1)  # (3, W, H)
            tensor = tensor[..., np.newaxis]  # (3, W, H, 1)
        else:  # guess: (W, H, D)
            tensor = tensor[np.newaxis]  # (1, W, H, D)
    elif ndim != 4:
        message = (
            f'{ndim}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert tensor.ndim == 4
    return tensor
432
433
434
def check_uint_to_int(array):
    """Cast uint16 to int32 and uint32 to int64; return other arrays as-is.

    This is because PyTorch won't take uint16 nor uint32. Each target type is
    wide enough to hold every value of its source type. Arrays of any other
    dtype are returned unchanged (the same object).
    """
    promotions = (
        (np.uint16, np.int32),
        (np.uint32, np.int64),
    )
    for source_dtype, target_dtype in promotions:
        if array.dtype == source_dtype:
            return array.astype(target_dtype)
    return array
441