torchio.data.io.write_image()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 10
nop 4
dl 0
loc 11
rs 9.9
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import warnings
4
from pathlib import Path
5
6
import nibabel as nib
7
import numpy as np
8
import numpy.typing as npt
9
import SimpleITK as sitk
10
import torch
11
from nibabel.filebasedimages import ImageFileError
12
from nibabel.spatialimages import SpatialImage
13
14
from ..constants import REPO_URL
15
from ..types import TypeData
16
from ..types import TypeDataAffine
17
from ..types import TypeDirection
18
from ..types import TypeDoubletInt
19
from ..types import TypePath
20
from ..types import TypeQuartetInt
21
from ..types import TypeTripletFloat
22
from ..types import TypeTripletInt
23
24
# Diagonal sign-flip matrices that convert between the LPS convention used by
# ITK/SimpleITK and the RAS convention used by NiBabel (x and y are negated).
FLIPXY_33 = np.diag((-1, -1, 1))
FLIPXY_44 = np.diag((-1, -1, 1, 1))

# Suffixes of image formats that are typically 2D, listed in lower and upper
# case so that plain (case-sensitive) membership tests work for both.
formats = ['.jpg', '.jpeg', '.bmp', '.png', '.tif', '.tiff']
IMAGE_2D_FORMATS = [*formats, *map(str.upper, formats)]
31
32
33
def read_image(path: TypePath) -> TypeDataAffine:
    """Read an image and its affine, trying SimpleITK first and NiBabel second.

    Args:
        path: Path to an image file, or to a directory holding a DICOM series.

    Returns:
        The ``(tensor, affine)`` pair returned by ``_read_sitk``, or by
        ``_read_nibabel`` when SimpleITK raises ``RuntimeError`` (a warning
        is emitted in that case).

    Raises:
        RuntimeError: If NiBabel cannot understand the file either; NiBabel's
            ``ImageFileError`` is chained as the cause.
    """
    try:
        return _read_sitk(path)
    except RuntimeError as sitk_error:  # fall back to NiBabel
        warnings.warn(
            f'Error loading image with SimpleITK:\n{sitk_error}\n\nTrying NiBabel...',
            stacklevel=2,
        )
        try:
            return _read_nibabel(path)
        except ImageFileError as nibabel_error:
            hint = (
                f'File "{path}" not understood.'
                ' Check supported formats by at'
                ' https://simpleitk.readthedocs.io/en/master/IO.html#images'
                ' and https://nipy.org/nibabel/api.html#file-formats'
            )
            raise RuntimeError(hint) from nibabel_error
50
51
52
def _read_nibabel(path: TypePath) -> TypeDataAffine:
    """Load an image with NiBabel as a float32 tensor plus its affine.

    This is the fallback used by ``read_image`` when SimpleITK fails. For a
    5D array, index 0 of the fourth axis is taken and the last axis is moved
    to the front (the layout ``_write_nibabel`` writes is
    ``(W, H, D, 1, C)``).
    """
    image: SpatialImage = nib.load(str(path), mmap=False)  # type: ignore[assignment]
    array = image.get_fdata(dtype=np.float32)
    if array.ndim == 5:
        array = array[..., 0, :].transpose(3, 0, 1, 2)
    tensor = torch.as_tensor(check_uint_to_int(array))
    affine = image.affine
    assert isinstance(affine, np.ndarray)
    return tensor, affine
63
64
65
def _read_sitk(path: TypePath) -> TypeDataAffine:
    """Read an image with SimpleITK as a channels-first tensor and RAS affine.

    A directory path is assumed to contain a DICOM series.
    """
    is_dicom_dir = Path(path).is_dir()
    sitk_image = _read_dicom(path) if is_dicom_dir else sitk.ReadImage(str(path))
    array, affine = sitk_to_nib(sitk_image, keepdim=True)
    return torch.as_tensor(check_uint_to_int(array)), affine
74
75
76
def _read_dicom(directory: TypePath):
    """Read the DICOM series found in ``directory`` as a SimpleITK image.

    Raises:
        FileNotFoundError: If ``directory`` is not an existing directory, or
            GDCM finds no DICOM series file names in it.
    """
    folder = Path(directory)
    # Unreachable when called from _read_sitk, which already checks is_dir()
    if not folder.is_dir():
        raise FileNotFoundError(f'Directory "{folder}" not found')
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(str(folder))
    if not file_names:
        raise FileNotFoundError(
            f'The directory "{folder}" does not seem to contain DICOM files'
        )
    series_reader.SetFileNames(file_names)
    return series_reader.Execute()
88
89
90
def read_shape(path: TypePath) -> TypeQuartetInt:
    """Read the shape of an image from its header, without loading voxels.

    Args:
        path: Path to an image file readable by SimpleITK.

    Returns:
        ``(num_channels, si, sj, sk)``. For 2D images ``sk`` is ``1``. For 4D
        images the last axis is interpreted as channels.

    Raises:
        ValueError: If the image does not have 2, 3 or 4 dimensions.
    """
    reader = get_reader(path)  # calls ReadImageInformation (header only)
    num_channels = reader.GetNumberOfComponents()
    return _shape_from_sitk_size(num_channels, tuple(reader.GetSize()))


def _shape_from_sitk_size(
    num_channels: int,
    size: tuple[int, ...],
) -> TypeQuartetInt:
    """Build a ``(C, I, J, K)`` shape from a SimpleITK size and component count."""
    num_dimensions = len(size)
    if num_dimensions == 2:
        si, sj = size
        sk = 1
    elif num_dimensions == 3:
        si, sj, sk = size
    elif num_dimensions == 4:
        # We assume a bad NIfTI file (channels encoded as a spatial dimension)
        si, sj, sk, num_channels = size
    else:
        # Raise explicitly rather than assert: asserts vanish under
        # ``python -O`` and si/sj/sk would then be unbound.
        message = (
            'Expected an image with 2, 3 or 4 dimensions,'
            f' but got {num_dimensions} (size {size})'
        )
        raise ValueError(message)
    return num_channels, si, sj, sk
113
114
115
def read_affine(path: TypePath) -> np.ndarray:
    """Return the 4x4 RAS affine of an image, reading only its header information."""
    return get_ras_affine_from_sitk(get_reader(path))
119
120
121
def get_reader(path: TypePath, read: bool = True) -> sitk.ImageFileReader:
    """Create a SimpleITK file reader pointed at ``path``.

    Args:
        path: Image file to read.
        read: If ``True`` (default), also call ``ReadImageInformation`` so
            that header metadata (size, spacing, direction...) can be queried.
    """
    file_reader = sitk.ImageFileReader()
    file_reader.SetFileName(str(path))
    if not read:
        return file_reader
    file_reader.ReadImageInformation()
    return file_reader
127
128
129
def write_image(
    tensor: torch.Tensor,
    affine: TypeData,
    path: TypePath,
    squeeze: bool | None = None,
) -> None:
    """Write a 4D image tensor to ``path``, trying SimpleITK first.

    If SimpleITK raises ``RuntimeError``, the image is written with NiBabel
    instead; ``squeeze`` only applies to the SimpleITK writer.

    Args:
        tensor: 4D image tensor (channels first).
        affine: RAS affine of the image.
        path: Output path; its suffix selects the file format.
        squeeze: Forwarded to ``_write_sitk``.
    """
    try:
        _write_sitk(tensor, affine, path, squeeze=squeeze)
    except RuntimeError:  # SimpleITK could not write it; fall back to NiBabel
        _write_nibabel(tensor, affine, path)
140
141
142
def _write_nibabel(
    tensor: torch.Tensor,
    affine: TypeData,
    path: TypePath,
) -> None:
    """Write an image using NiBabel.

    Expects a path with an extension that can be used by nibabel.save to
    write a NIfTI-1 image, such as '.nii.gz' or '.img'.

    Args:
        tensor: 4D tensor whose first axis holds the components (channels).
        affine: Affine passed to the NiBabel image constructor.
        path: Output path. Any '.gz' in it is ignored when detecting the
            format; a '.nii' suffix writes a single-file NIfTI-1 image and
            '.hdr'/'.img' writes a header/image pair.

    Raises:
        ImageFileError: If the suffix is not one of the NIfTI-1 suffixes.
    """
    assert tensor.ndim == 4
    num_components = tensor.shape[0]

    # NIfTI components must be at the end, in a 5D array
    if num_components == 1:
        tensor = tensor[0]
    else:
        tensor = tensor[np.newaxis].permute(2, 3, 4, 0, 1)  # (W, H, D, 1, C)
    suffix = Path(str(path).replace('.gz', '')).suffix
    img: nib.nifti1.Nifti1Image | nib.nifti1.Nifti1Pair
    if '.nii' in suffix:
        img = nib.nifti1.Nifti1Image(np.asarray(tensor), affine)
    elif '.hdr' in suffix or '.img' in suffix:
        img = nib.nifti1.Nifti1Pair(np.asarray(tensor), affine)
    else:
        # A bare ImageFileError gave the caller no clue what was wrong
        message = (
            f'Cannot write "{path}" with NiBabel: suffix "{suffix}" is not'
            ' supported. Use ".nii", ".nii.gz", ".hdr" or ".img"'
        )
        raise ImageFileError(message)
    assert isinstance(img.header, nib.nifti1.Nifti1Header)
    if num_components > 1:
        img.header.set_intent('vector')
    # Mark the qform as set (code 1) and the sform as unused (code 0)
    img.header['qform_code'] = 1
    img.header['sform_code'] = 0
    nib.loadsave.save(img, str(path))
174
175
176
def _write_sitk(
    tensor: torch.Tensor,
    affine: TypeData,
    path: TypePath,
    use_compression: bool = True,
    squeeze: bool | None = None,
) -> None:
    """Write a 4D image tensor to disk using SimpleITK.

    Args:
        tensor: 4D image tensor, channels first. It may require grad or live
            on any device; it is detached and copied to CPU before writing.
        affine: RAS affine of the image.
        path: Output path; its suffix selects the SimpleITK writer.
        use_compression: Passed to ``sitk.WriteImage``.
        squeeze: If ``None``, suffixes in ``IMAGE_2D_FORMATS`` may be written
            as 2D images (singleton last axis dropped) and everything else is
            forced to 3D. Otherwise ``squeeze=True`` allows the 2D squeeze
            and ``squeeze=False`` forces a 3D image.
    """
    assert tensor.ndim == 4
    path = Path(path)
    # .numpy() fails on tensors that require grad or are not on the CPU
    array = tensor.detach().cpu().numpy()
    # Compare case-insensitively, consistent with IMAGE_2D_FORMATS, so that
    # e.g. '.PNG' is also cast instead of failing in both writers.
    if path.suffix.lower() in ('.png', '.jpg', '.jpeg', '.bmp'):
        warnings.warn(
            f'Casting to uint 8 before saving to {path}',
            RuntimeWarning,
            stacklevel=2,
        )
        # NOTE: astype does not clip; values outside [0, 255] are not saturated
        array = array.astype(np.uint8)
    if squeeze is None:
        force_3d = path.suffix not in IMAGE_2D_FORMATS
    else:
        force_3d = not squeeze
    image = nib_to_sitk(array, affine, force_3d=force_3d)
    sitk.WriteImage(image, str(path), use_compression)
199
200
201
def read_matrix(path: TypePath):
    """Read an affine transform file and return it as a tensor.

    The reader is chosen from the suffix: ``.tfm``/``.h5`` files are read as
    ITK transforms and ``.txt``/``.trsf`` files as NiftyReg/blockmatching
    matrices.

    Raises:
        ValueError: If the suffix is none of the above.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        return _read_niftyreg_matrix(path)
    if suffix in ('.tfm', '.h5'):  # ITK
        return _read_itk_matrix(path)
    raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
212
213
214
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform to ``path``.

    The format is chosen from the suffix: ``.tfm``/``.h5`` files are written
    as ITK transforms and ``.txt``/``.trsf`` files as NiftyReg/blockmatching
    matrices.

    Raises:
        ValueError: If the suffix is none of the above.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        # Fail loudly, like read_matrix, instead of silently writing nothing
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
222
223
224
def _to_itk_convention(matrix: TypeData) -> np.ndarray:
    """Convert a RAS affine into the ITK (LPS) convention.

    Torch tensors are converted to NumPy first. The matrix is conjugated with
    the x/y flip (``F @ M @ F``) and the result is inverted.
    """
    array = matrix.numpy() if isinstance(matrix, torch.Tensor) else matrix
    conjugated = np.dot(np.dot(FLIPXY_44, array), FLIPXY_44)
    return np.linalg.inv(conjugated)
232
233
234
def _from_itk_convention(matrix: TypeData) -> np.ndarray:
    """Convert an affine from the ITK (LPS) convention back to RAS.

    The matrix is conjugated with the x/y flip and then inverted, which
    undoes ``_to_itk_convention``.
    """
    conjugated = np.dot(FLIPXY_44, np.dot(matrix, FLIPXY_44))
    return np.linalg.inv(conjugated)
240
241
242
def _read_itk_matrix(path: TypePath) -> torch.Tensor:
    """Read an affine transform from an ITK transform file (.tfm or .h5).

    Returns:
        The 4x4 transform in the RAS convention, as a float64 tensor.
    """
    transform = sitk.ReadTransform(str(path))
    homogeneous_matrix_lps = _itk_parameters_to_matrix(
        transform.GetParameters(),
        transform.GetFixedParameters(),
    )
    homogeneous_matrix_ras = _from_itk_convention(homogeneous_matrix_lps)
    return torch.as_tensor(homogeneous_matrix_ras)


def _itk_parameters_to_matrix(parameters, fixed_parameters) -> np.ndarray:
    """Build a homogeneous 4x4 (LPS) matrix from ITK affine parameters.

    ITK affine transforms map ``x`` to ``A (x - c) + c + t``, where ``A`` (9
    values) and ``t`` (3 values) are the parameters and the center ``c`` is
    stored in the fixed parameters. The homogeneous translation is therefore
    ``t + c - A c``; using ``t`` alone is only correct when ``c`` is zero.
    """
    rotation_matrix = np.array(parameters[:9], dtype=np.float64).reshape(3, 3)
    translation_vector = np.array(parameters[9:], dtype=np.float64).reshape(3)
    if len(fixed_parameters) == 3:
        center = np.array(fixed_parameters, dtype=np.float64)
    else:  # unexpected layout: assume the center is the origin
        center = np.zeros(3)
    offset = translation_vector + center - rotation_matrix @ center
    matrix = np.eye(4)
    matrix[:3, :3] = rotation_matrix
    matrix[:3, 3] = offset
    return matrix
254
255
256
def _write_itk_matrix(matrix: TypeData, tfm_path: TypePath) -> None:
    """Write ``matrix`` to an ITK transform file.

    The tfm file contains the matrix from floating to reference.
    """
    _matrix_to_itk_transform(matrix).WriteTransform(str(tfm_path))
260
261
262
def _matrix_to_itk_transform(
    matrix: TypeData,
    dimensions: int = 3,
) -> sitk.AffineTransform:
    """Build a SimpleITK affine transform from a 4x4 RAS matrix.

    The matrix is first converted with ``_to_itk_convention``. The top-left
    ``dimensions`` x ``dimensions`` block is used as the linear part, and the
    first ``dimensions`` entries of column 3 as the translation.
    """
    itk_matrix = _to_itk_convention(matrix)
    linear_part = itk_matrix[:dimensions, :dimensions]
    offset = itk_matrix[:dimensions, 3]
    return sitk.AffineTransform(linear_part.ravel().tolist(), offset.tolist())
271
272
273
def _read_niftyreg_matrix(trsf_path: TypePath) -> torch.Tensor:
    """Read a NiftyReg matrix file and return its inverse as a tensor.

    The matrix is stored as plain text. It is inverted on reading, mirroring
    ``_write_niftyreg_matrix``, which inverts before saving.

    Returns:
        The inverted matrix as a float64 ``torch.Tensor``.
    """
    # Not named ``read_matrix`` to avoid shadowing the module-level function
    stored_matrix = np.loadtxt(trsf_path).astype(np.float64)
    return torch.from_numpy(np.linalg.inv(stored_matrix))
278
279
280
def _write_niftyreg_matrix(matrix: TypeData, txt_path: TypePath) -> None:
    """Write an affine transform in NiftyReg's .txt format (ref -> flo).

    The inverse of ``matrix`` is saved as text with 8 decimal places.
    """
    inverse = np.linalg.inv(matrix)
    np.savetxt(txt_path, inverse, fmt='%.8f')
284
285
286
def get_rotation_and_spacing_from_affine(
    affine: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Split the 3x3 linear part of an affine into directions and spacing.

    The spacing of each axis is the Euclidean norm of the corresponding
    column. The returned rotation is the linear part with each column divided
    by that norm (a pure rotation only when the affine has no shear).

    Adapted from https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    """
    linear = affine[:3, :3]
    spacing = np.sqrt(np.square(linear).sum(axis=0))
    return linear / spacing, spacing
294
295
296
def nib_to_sitk(
    data: TypeData,
    affine: TypeData,
    force_3d: bool = False,
    force_4d: bool = False,
) -> sitk.Image:
    """Create a SimpleITK image from a 4D tensor and a 4x4 RAS affine.

    Accepted layouts are ``(1, W, H, 1)``, ``(C, W, H, 1)``, ``(1, W, H, D)``
    and ``(C, W, H, D)``.

    Args:
        data: 4D array or tensor, channels first.
        affine: 4x4 RAS affine, converted to SimpleITK origin, spacing and
            direction via ``get_sitk_metadata_from_ras_affine``.
        force_3d: Keep a singleton last axis instead of producing a 2D image.
        force_4d: Neither treat the first axis as vector components nor drop
            it. NOTE(review): origin, spacing and direction are still 3D in
            this case, and the assertions below expect a single channel;
            confirm this path is supported.

    Returns:
        The SimpleITK image. It has ``data.shape[0]`` components per pixel.

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        shape = tuple(data.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    array = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    num_channels = array.shape[0]
    last_spatial_size = array.shape[3]
    is_multichannel = num_channels > 1 and not force_4d
    is_2d = last_spatial_size == 1 and not force_3d

    if is_2d:
        array = array[..., 0]  # drop the singleton last spatial axis
    keep_channel_axis = is_multichannel or force_4d
    if not keep_channel_axis:
        array = array[0]  # scalar image: drop the channel axis
    # Reverse the axes: SimpleITK reads NumPy arrays in (z, y, x[, c]) order
    image = sitk.GetImageFromArray(array.transpose(), isVector=is_multichannel)

    origin, spacing, direction = get_sitk_metadata_from_ras_affine(
        affine,
        is_2d=is_2d,
    )
    image.SetOrigin(origin)  # should I add a 4th value if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(direction)

    # Sanity checks: components and spatial size must match the input
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetSize() == data.shape[1 : 1 + num_spatial_dims]
    return image
337
338
339
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a channels-first array and a RAS affine.

    Args:
        image: SimpleITK image with 2, 3 or 4 dimensions.
        keepdim: If ``False``, the array is passed through ``ensure_4d``.

    Returns:
        ``(data, affine)``. ``data`` has the number of components on its
        first axis and ``affine`` is a 4x4 RAS matrix.
    """
    array = check_uint_to_int(sitk.GetArrayFromImage(image).transpose())
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        array = array[np.newaxis]  # add channels dimension
    spatial_dims = image.GetDimension()
    if spatial_dims == 2:
        array = array[..., np.newaxis]
    elif spatial_dims == 4:
        # Probably a bad NIfTI (1, sx, sy, sz, c): move the last axis to the
        # channels position and treat the image as 3D
        num_components = array.shape[-1]
        array = array[0].transpose(3, 0, 1, 2)
        spatial_dims = 3
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=spatial_dims).numpy()
    assert array.shape[0] == num_components
    return array, get_ras_affine_from_sitk(image)
362
363
364
def get_ras_affine_from_sitk(
    sitk_object: sitk.Image | sitk.ImageFileReader,
) -> np.ndarray:
    """Compute the 4x4 RAS affine of a SimpleITK image or file reader.

    SimpleITK reports origin and direction in LPS; the x and y axes are
    flipped to obtain RAS. 2D metadata is embedded in 3D, with unit spacing
    and zero origin along the third axis. 4D metadata is truncated to its
    first three axes.

    Raises:
        ValueError: If the direction does not have 4, 9 or 16 values
            (i.e. the object is not 2D, 3D or 4D).
    """
    spacing = np.array(sitk_object.GetSpacing(), dtype=np.float64)
    direction_lps = np.array(sitk_object.GetDirection(), dtype=np.float64)
    origin_lps = np.array(sitk_object.GetOrigin(), dtype=np.float64)
    direction_length = len(direction_lps)
    rotation_lps: npt.NDArray[np.float64]
    if direction_length == 9:
        rotation_lps = direction_lps.reshape(3, 3)
    elif direction_length == 4:  # ignore last dimension if 2D (1, W, H, 1)
        rotation_lps_2d = direction_lps.reshape(2, 2)
        rotation_lps = np.eye(3)
        rotation_lps[:2, :2] = rotation_lps_2d
        spacing = np.append(spacing, 1)
        origin_lps = np.append(origin_lps, 0)
    elif direction_length == 16:  # probably a bad NIfTI. Let's try to fix it
        rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
        spacing = spacing[:-1]
        origin_lps = origin_lps[:-1]
    else:
        # Explicit error instead of an UnboundLocalError on rotation_lps
        message = (
            'Expected a 2D, 3D or 4D direction matrix (4, 9 or 16 values),'
            f' but got {direction_length} values'
        )
        raise ValueError(message)
    rotation_ras = np.dot(FLIPXY_33, rotation_lps)
    rotation_ras_zoom = rotation_ras * spacing  # scale each column by its spacing
    translation_ras = np.dot(FLIPXY_33, origin_lps)
    affine = np.eye(4)
    affine[:3, :3] = rotation_ras_zoom
    affine[:3, 3] = translation_ras
    return affine
391
392
393
def get_sitk_metadata_from_ras_affine(
    affine: np.ndarray,
    is_2d: bool = False,
    lps: bool = True,
) -> tuple[TypeTripletFloat, TypeTripletFloat, TypeDirection]:
    """Derive SimpleITK origin, spacing and direction from a 4x4 RAS affine.

    Args:
        affine: 4x4 RAS affine.
        is_2d: If ``True``, the affine's orientation is ignored and the
            direction is a fixed 2x2 matrix: ``diag(-1, -1)`` in LPS, the
            identity in RAS. Origin and spacing remain 3-tuples.
        lps: If ``True`` (default), return origin and direction in LPS
            (SimpleITK's convention); otherwise in RAS.

    Returns:
        ``(origin, spacing, direction)``. The direction is a flattened 2x2
        (4 values) or 3x3 (9 values) matrix.
    """
    rotation_ras, spacing_array = get_rotation_and_spacing_from_affine(affine)
    origin_ras = affine[:3, 3]
    origin_lps = np.dot(FLIPXY_33, origin_ras)
    if is_2d:  # ignore orientation if 2D (1, W, H, 1)
        rotation_lps = np.diag((-1, -1)).astype(np.float64)
        rotation_ras = np.diag((1, 1)).astype(np.float64)
    else:
        rotation_lps = np.dot(FLIPXY_33, rotation_ras)
    origin_array = origin_lps if lps else origin_ras
    direction_array = (rotation_lps if lps else rotation_ras).flatten()

    # Explicit unpacking keeps the tuple lengths visible to mypy
    ox, oy, oz = origin_array
    sx, sy, sz = spacing_array
    direction: TypeDirection
    if is_2d:
        d1, d2, d3, d4 = direction_array
        direction = d1, d2, d3, d4
    else:
        d1, d2, d3, d4, d5, d6, d7, d8, d9 = direction_array
        direction = d1, d2, d3, d4, d5, d6, d7, d8, d9
    return (ox, oy, oz), (sx, sy, sz), direction
422
423
424
def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> torch.Tensor:
    """Return ``tensor`` as a 4D ``torch.Tensor`` laid out as ``(C, W, H, D)``.

    - 2D ``(W, H)``: single-channel 2D image, giving ``(1, W, H, 1)``.
    - 3D: ambiguous. ``num_spatial_dims=2`` means ``(C, W, H)`` and
      ``num_spatial_dims=3`` means ``(W, H, D)``. Otherwise a size of 3 in
      the first or last axis is taken as RGB channels, and anything else is
      treated as a single-channel volume.
    - 4D: returned unchanged.
    - 5D ``(W, H, D, 1, C)``: the singleton axis is dropped and the channels
      are moved first.

    Raises:
        ValueError: If a 5D input has ``shape[-2] != 1``, or the input has
            any other number of dimensions.
    """
    # I wish named tensors were properly supported in PyTorch
    result = torch.as_tensor(tensor)
    ndim = result.ndim
    if ndim == 2:  # assume 2D monochannel (W, H)
        result = result[np.newaxis, ..., np.newaxis]  # (1, W, H, 1)
    elif ndim == 3:  # 2D multichannel or 3D monochannel?
        shape = result.shape
        if num_spatial_dims == 2:
            result = result[..., np.newaxis]  # (C, W, H, 1)
        elif num_spatial_dims == 3 or 3 not in (shape[0], shape[-1]):
            result = result[np.newaxis]  # (1, W, H, D)
        else:  # guessed RGB
            if shape[-1] == 3:  # (W, H, 3)
                result = result.permute(2, 0, 1)  # (3, W, H)
            result = result[..., np.newaxis]  # (3, W, H, 1)
    elif ndim == 5:  # hope (W, H, D, 1, C)
        if result.shape[-2] != 1:
            raise ValueError('5D is not supported for shape[-2] > 1')
        result = result[..., 0, :].permute(3, 0, 1, 2)
    elif ndim != 4:
        message = (
            f'{ndim}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert result.ndim == 4
    return result
460
461
462
def check_uint_to_int(array: np.ndarray) -> np.ndarray:
    """Widen uint16/uint32 arrays to int32/int64; return others unchanged.

    This is because PyTorch won't take uint16 nor uint32. Each type is mapped
    to a signed type wide enough to hold all of its values. Arrays of any
    other dtype are returned as the same object, without copying.
    """
    for unsigned, signed in ((np.uint16, np.int32), (np.uint32, np.int64)):
        if array.dtype == unsigned:
            return array.astype(signed)
    return array
469