Passed
Push — master ( cf656a...e4342b )
by Fernando
01:45
created

torchio.data.io.write_image()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 10
nop 4
dl 0
loc 11
rs 9.9
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Tuple
4
5
import torch
6
import numpy as np
7
import nibabel as nib
8
import SimpleITK as sitk
9
10
from ..constants import REPO_URL
11
from ..typing import TypePath, TypeData, TypeTripletInt
12
13
14
# Homogeneous 4x4 matrix that negates the first two axes. The transform
# helpers in this module use it to move matrices between RAS and LPS.
FLIPXY = np.diag((-1, -1, 1, 1))
15
16
17
def read_image(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image, trying SimpleITK first and NiBabel as a fallback.

    Args:
        path: Location of the image to read.

    Returns:
        The ``(tensor, affine)`` pair returned by the reader that succeeded.

    Raises:
        RuntimeError: If SimpleITK fails with a ``RuntimeError`` and NiBabel
            does not understand the file either.
    """
    try:
        return _read_sitk(path)
    except RuntimeError:  # try with NiBabel
        try:
            return _read_nibabel(path)
        except nib.loadsave.ImageFileError:
            supported_formats_hint = (
                ' Check supported formats by at'
                ' https://simpleitk.readthedocs.io/en/master/IO.html#images'
                ' and https://nipy.org/nibabel/api.html#file-formats'
            )
            raise RuntimeError(
                f'File "{path}" not understood.' + supported_formats_hint
            )
32
33
34
def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image with NiBabel and return its data and affine.

    The data are read as float32 without memory mapping. A 5D array is
    reduced to 4D by taking index 0 of its fourth axis and moving the
    last axis to the front.
    """
    nii = nib.load(str(path), mmap=False)
    array = nii.get_fdata(dtype=np.float32)
    if array.ndim == 5:
        # Presumably (W, H, D, 1, C) -> (C, W, H, D)
        array = array[..., 0, :].transpose(3, 0, 1, 2)
    array = check_uint_to_int(array)
    return torch.as_tensor(array), nii.affine
44
45
46
def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image with SimpleITK and return its data and affine.

    A directory is assumed to contain a DICOM series; anything else is
    passed to ``sitk.ReadImage``.
    """
    is_directory = Path(path).is_dir()  # assume DICOM
    image = _read_dicom(path) if is_directory else sitk.ReadImage(str(path))
    array, affine = sitk_to_nib(image, keepdim=True)
    return torch.as_tensor(check_uint_to_int(array)), affine
55
56
57
def _read_dicom(directory: TypePath):
    """Read the DICOM series found in a directory using SimpleITK.

    Args:
        directory: Directory expected to contain the DICOM files.

    Returns:
        The image produced by ``sitk.ImageSeriesReader``.

    Raises:
        FileNotFoundError: If the directory does not exist or no DICOM
            series file names are found in it.
    """
    directory = Path(directory)
    if not directory.is_dir():  # unreachable if called from _read_sitk
        raise FileNotFoundError(f'Directory "{directory}" not found')
    series_reader = sitk.ImageSeriesReader()
    series_files = series_reader.GetGDCMSeriesFileNames(str(directory))
    if series_files:
        series_reader.SetFileNames(series_files)
        return series_reader.Execute()
    raise FileNotFoundError(
        f'The directory "{directory}" does not seem to contain DICOM files'
    )
72
73
74
def read_shape(path: TypePath) -> Tuple[int, int, int, int]:
    """Read the image shape from the file metadata using SimpleITK.

    Only the image information is read through ``ReadImageInformation``.
    The result is the number of components followed by the spatial size;
    a 2D image gets a trailing spatial dimension of size 1.
    """
    reader = sitk.ImageFileReader()
    reader.SetFileName(str(path))
    reader.ReadImageInformation()
    channels = (reader.GetNumberOfComponents(),)
    size = reader.GetSize()
    if reader.GetDimension() == 2:
        size = tuple(size) + (1,)
    return channels + size
83
84
85
def read_affine(path: TypePath) -> np.ndarray:
    """Build the affine matrix of an image from its file metadata.

    Only the image information is read through ``ReadImageInformation``;
    spacing, direction and origin are then passed to
    ``get_ras_affine_from_sitk_metadata``.
    """
    reader = sitk.ImageFileReader()
    reader.SetFileName(str(path))
    reader.ReadImageInformation()
    spacing = reader.GetSpacing()
    direction = reader.GetDirection()
    origin = reader.GetOrigin()
    return get_ras_affine_from_sitk_metadata(spacing, direction, origin)
95
96
97
def write_image(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = True,
        ) -> None:
    """Write an image, trying SimpleITK first and NiBabel as a fallback.

    Args:
        tensor: Image data to write.
        affine: Affine matrix associated with the image.
        path: Output path.
        squeeze: Forwarded to the NiBabel writer only; the SimpleITK
            writer does not receive it.
    """
    try:
        _write_sitk(tensor, affine, path)
    except RuntimeError:  # try with NiBabel
        _write_nibabel(tensor, affine, path, squeeze=squeeze)
108
109
110
def _write_nibabel(
        tensor: TypeData,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = False,
        ) -> None:
    """Write a 4D image to disk using NiBabel.

    Expects a path with an extension that can be used by nibabel.save
    to write a NIfTI-1 image, such as '.nii.gz' or '.img'

    Args:
        tensor: 4D tensor or array whose first axis holds the components.
        affine: Affine matrix passed to the NiBabel image constructor.
        path: Output path.
        squeeze: If ``True``, all singleton dimensions are removed before
            the image is created.

    Raises:
        ValueError: If the input does not have four dimensions.
        nib.loadsave.ImageFileError: If the extension is not one of the
            handled ones.
    """
    # Explicit check instead of assert, which is removed when running
    # Python with -O
    if tensor.ndim != 4:
        shape = tuple(tensor.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    # Convert once so that NumPy arrays work as well as tensors; the
    # torch-only permute() used to fail for multi-component arrays
    array = np.asarray(tensor)
    num_components = array.shape[0]

    # NIfTI components must be at the end, in a 5D array
    if num_components == 1:
        array = array[0]
    else:
        # (C, W, H, D) -> (1, C, W, H, D) -> (W, H, D, 1, C)
        array = array[np.newaxis].transpose(2, 3, 4, 0, 1)
    if squeeze:
        array = array.squeeze()
    suffix = Path(str(path).replace('.gz', '')).suffix
    if '.nii' in suffix:
        img = nib.Nifti1Image(array, affine)
    elif '.hdr' in suffix or '.img' in suffix:
        img = nib.Nifti1Pair(array, affine)
    else:
        message = f'Cannot write "{path}": extension not supported'
        raise nib.loadsave.ImageFileError(message)
    if num_components > 1:
        img.header.set_intent('vector')
    img.header['qform_code'] = 1
    img.header['sform_code'] = 0
    nib.save(img, str(path))
141
142
143
def _write_sitk(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        use_compression: bool = True,
        ) -> None:
    """Write a 4D image to disk using SimpleITK.

    Args:
        tensor: 4D tensor to write.
        affine: Affine matrix used to build the SimpleITK image.
        path: Output path. For PNG and JPEG extensions (in any letter
            case) the data are cast to ``uint8`` and a warning is issued.
        use_compression: Forwarded to ``sitk.WriteImage``.

    Raises:
        ValueError: If the input does not have four dimensions.
    """
    # Explicit check instead of assert, which is removed when running
    # Python with -O
    if tensor.ndim != 4:
        shape = tuple(tensor.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    path = Path(path)
    # Compare in lower case so that e.g. ".PNG" also triggers the cast
    if path.suffix.lower() in ('.png', '.jpg', '.jpeg'):
        warnings.warn(
            f'Casting to uint 8 before saving to {path}',
            RuntimeWarning,
        )
        tensor = tensor.numpy().astype(np.uint8)
    image = nib_to_sitk(tensor, affine)
    sitk.WriteImage(image, str(path), use_compression)
159
160
161
def read_matrix(path: TypePath):
    """Read an affine transform and convert to tensor.

    The reader is chosen from the file suffix: '.tfm' and '.h5' use the
    ITK reader, '.txt' and '.trsf' the NiftyReg one.

    Raises:
        ValueError: If the suffix is none of the above.
    """
    path = Path(path)
    extension = path.suffix
    if extension in ('.tfm', '.h5'):  # ITK
        return _read_itk_matrix(path)
    if extension in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        return _read_niftyreg_matrix(path)
    raise ValueError(f'Unknown suffix for transform file: "{extension}"')
172
173
174
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform.

    The writer is chosen from the file suffix: '.tfm' and '.h5' use the
    ITK writer, '.txt' and '.trsf' the NiftyReg one.

    Raises:
        ValueError: If the suffix is none of the above. Without this the
            call would return without writing anything.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        # Same error as read_matrix, instead of silently doing nothing
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
182
183
184
def _to_itk_convention(matrix):
    """RAS to LPS

    The matrix is multiplied by FLIPXY on both sides and the product is
    inverted.
    """
    flipped = np.dot(np.dot(FLIPXY, matrix), FLIPXY)
    return np.linalg.inv(flipped)
190
191
192
def _from_itk_convention(matrix):
    """LPS to RAS

    The matrix is multiplied by FLIPXY on both sides and the product is
    inverted.
    """
    flipped = np.dot(FLIPXY, np.dot(matrix, FLIPXY))
    return np.linalg.inv(flipped)
198
199
200
def _read_itk_matrix(path):
    """Read an affine transform in ITK's .tfm format

    The first nine transform parameters are used as the 3x3 matrix and
    the remaining ones as the translation. The resulting homogeneous
    matrix is converted with ``_from_itk_convention`` and returned as a
    tensor.
    """
    parameters = sitk.ReadTransform(str(path)).GetParameters()
    homogeneous_lps = np.eye(4)
    homogeneous_lps[:3, :3] = np.reshape(parameters[:9], (3, 3))
    homogeneous_lps[:3, 3] = np.reshape(parameters[9:], 3)
    return torch.as_tensor(_from_itk_convention(homogeneous_lps))
212
213
214
def _write_itk_matrix(matrix, tfm_path):
    """The tfm file contains the matrix from floating to reference.

    The matrix is converted to a SimpleITK transform, which is then
    written to ``tfm_path``.
    """
    itk_transform = _matrix_to_itk_transform(matrix)
    output_path = str(tfm_path)
    itk_transform.WriteTransform(output_path)
218
219
220
def _matrix_to_itk_transform(matrix, dimensions=3):
    """Build a SimpleITK affine transform from a homogeneous matrix.

    The matrix is first converted with ``_to_itk_convention``. Its upper
    left block is flattened into the transform matrix, and the translation
    is read from the column at index 3.
    """
    itk_matrix = _to_itk_convention(matrix)
    linear_part = itk_matrix[:dimensions, :dimensions]
    offset = itk_matrix[:dimensions, 3]
    return sitk.AffineTransform(linear_part.ravel().tolist(), offset.tolist())
226
227
228
def _read_niftyreg_matrix(trsf_path):
    """Read a NiftyReg matrix from a text file and return it as a tensor.

    The matrix stored in the file is inverted before being returned.
    """
    stored_matrix = np.loadtxt(trsf_path)
    return torch.as_tensor(np.linalg.inv(stored_matrix))
233
234
235
def _write_niftyreg_matrix(matrix, txt_path):
    """Write an affine transform in NiftyReg's .txt format (ref -> flo)

    The inverse of the given matrix is what gets written, with eight
    decimal places per value.
    """
    np.savetxt(txt_path, np.linalg.inv(matrix), fmt='%.8f')
239
240
241
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the upper-left 3x3 block of an affine into rotation and spacing.

    The spacing is the Euclidean norm of each column of the block, and
    the rotation is the block with every column divided by its norm.

    Returns:
        Tuple ``(rotation, spacing)``.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear_block = affine[:3, :3]
    column_norms = np.sqrt(np.sum(np.square(linear_block), axis=0))
    return linear_block / column_norms, column_norms
249
250
251
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a tensor and a 4x4 affine matrix.

    Args:
        data: 4D input. The first axis is compared against the number of
            components of the image and the following axes against its
            size.
        affine: Affine matrix; it is converted to a float64 array.
        force_3d: If ``True``, the input is not treated as 2D even when
            its last axis has size 1.
        force_4d: If ``True``, the first axis is neither stored as vector
            components nor dropped.

    Returns:
        Image whose origin, spacing and direction are set from ``affine``.

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        shape = tuple(data.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    # Possible input shapes:
    # (1, w, h, 1), (c, w, h, 1), (1, w, h, d), (c, w, h, d)
    voxels = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    is_multichannel = not force_4d and voxels.shape[0] > 1
    is_2d = not force_3d and voxels.shape[3] == 1
    if is_2d:
        voxels = voxels[..., 0]
    if not (is_multichannel or force_4d):
        voxels = voxels[0]
    # The axes order is reversed before handing the array to SimpleITK
    image = sitk.GetImageFromArray(
        voxels.transpose(),
        isVector=is_multichannel,
    )

    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    ras_to_lps = np.diag((-1, -1, 1))  # used to switch between LPS and RAS
    origin = np.dot(ras_to_lps, affine[:3, 3])
    direction = np.dot(ras_to_lps, rotation)
    if is_2d:  # ignore first dimension if 2D (1, W, H, 1)
        direction = direction[:2, :2]
    image.SetOrigin(origin)  # should I add a 4th value if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(direction.flatten())
    # Sanity checks of the result against the input shape
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetSize() == data.shape[1: 1 + num_spatial_dims]
    return image
292
293
294
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the data and the affine matrix from a SimpleITK image.

    The array obtained from SimpleITK has its axes reversed. A leading
    axis is added for single-component images and a trailing one for 2D
    images. Unless ``keepdim`` is ``True``, the result is also passed
    through ``ensure_4d``.

    Returns:
        Tuple ``(data, affine)``, where the affine is computed from the
        image spacing, direction and origin.
    """
    array = check_uint_to_int(sitk.GetArrayFromImage(image).transpose())
    num_components = image.GetNumberOfComponentsPerPixel()
    num_spatial_dims = image.GetDimension()
    if num_components == 1:
        array = array[np.newaxis]  # add channels dimension
    if num_spatial_dims == 2:
        array = array[..., np.newaxis]
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=num_spatial_dims)
    # Sanity checks of the result against the image metadata
    assert array.shape[0] == num_components
    assert array.shape[1: 1 + num_spatial_dims] == image.GetSize()
    spacing = image.GetSpacing()
    direction = image.GetDirection()
    origin = image.GetOrigin()
    return array, get_ras_affine_from_sitk_metadata(spacing, direction, origin)
316
317
318
def get_ras_affine_from_sitk_metadata(
        spacing: TypeTripletInt,
        direction_lps: TypeTripletInt,
        origin_lps: TypeTripletInt,
        ) -> np.ndarray:
    """Build a 4x4 RAS affine matrix from SimpleITK-style metadata.

    Args:
        spacing: Spacing along each axis.
        direction_lps: Flattened direction matrix with 9 values (3D) or
            4 values (2D).
        origin_lps: Origin of the image.

    Returns:
        4x4 array. The first two rows of the direction and the first two
        entries of the origin are negated, and the direction columns are
        scaled by the spacing. For 2D input, a third axis with spacing 1
        and origin 0 is appended.

    Raises:
        RuntimeError: If the direction has neither 9 nor 4 values.
    """
    spacing = np.array(spacing)
    direction = np.array(direction_lps)
    num_values = len(direction)
    if num_values == 9:
        rotation = direction.reshape(3, 3)
    elif num_values == 4:  # ignore first dimension if 2D (1, W, H, 1)
        rotation = np.eye(3)
        rotation[:2, :2] = direction.reshape(2, 2)
        spacing = np.append(spacing, 1)
        origin_lps = tuple(origin_lps) + (0,)
    else:
        raise RuntimeError(f'Direction not understood: {direction}')
    lps_to_ras = np.diag((-1, -1, 1))  # used to switch between LPS and RAS
    affine = np.eye(4)
    affine[:3, :3] = np.dot(lps_to_ras, rotation) * spacing
    affine[:3, 3] = np.dot(lps_to_ras, origin_lps)
    return affine
343
344
345
def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> TypeData:
    """Convert the input to a tensor and give it four dimensions.

    Args:
        tensor: Input with 2, 3, 4 or 5 dimensions.
        num_spatial_dims: For 3D input, 2 appends a trailing axis and 3
            prepends a leading one. Any other value makes the function
            guess: if the first or last axis has size 3 the input is
            treated as a 2D image with 3 channels, otherwise a leading
            axis is prepended.

    Returns:
        4D tensor.

    Raises:
        ValueError: If the input is 5D and its second to last axis does
            not have size 1, or if the number of dimensions is not
            supported.
    """
    # I wish named tensors were properly supported in PyTorch
    tensor = torch.as_tensor(tensor)
    rank = tensor.ndim
    if rank == 4:
        return tensor
    if rank == 5:  # hope (W, H, D, 1, C)
        if tensor.shape[-2] != 1:
            raise ValueError('5D is not supported for shape[-2] > 1')
        return tensor[..., 0, :].permute(3, 0, 1, 2)
    if rank == 2:  # assume 2D monochannel (W, H)
        return tensor[np.newaxis, ..., np.newaxis]  # (1, W, H, 1)
    if rank != 3:
        message = (
            f'{rank}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)

    # 3D input: 2D multichannel or 3D monochannel?
    if num_spatial_dims == 2:
        return tensor[..., np.newaxis]  # (C, W, H, 1)
    if num_spatial_dims == 3:  # (W, H, D)
        return tensor[np.newaxis]  # (1, W, H, D)
    # Try to guess
    first_size, last_size = tensor.shape[0], tensor.shape[-1]
    if 3 not in (first_size, last_size):  # (W, H, D)
        return tensor[np.newaxis]  # (1, W, H, D)
    if last_size == 3:  # (W, H, 3)
        tensor = tensor.permute(2, 0, 1)  # (3, W, H)
    return tensor[..., np.newaxis]  # (3, W, H, 1)
381
382
383
def check_uint_to_int(array):
    """Cast uint16 and uint32 arrays to a wider signed integer type.

    uint16 becomes int32 and uint32 becomes int64; any other dtype is
    returned untouched.
    """
    # This is because PyTorch won't take uint16 nor uint32
    conversions = (
        (np.uint16, np.int32),
        (np.uint32, np.int64),
    )
    for unsigned_type, signed_type in conversions:
        if array.dtype == unsigned_type:
            return array.astype(signed_type)
    return array
390