Passed
Pull Request — master (#640)
by Fernando
01:22
created

torchio.data.io.check_uint_to_int()   A

Complexity

Conditions 3

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 6
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Tuple, Union
4
5
import torch
6
import numpy as np
7
import nibabel as nib
8
import SimpleITK as sitk
9
10
from ..constants import REPO_URL
11
from ..typing import TypePath, TypeData, TypeTripletFloat
12
13
14
# Homogeneous 4x4 matrix negating x and y; used to switch between RAS and LPS
FLIPXY = np.diag((-1, -1, 1, 1))
15
16
17
def read_image(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image file (or a DICOM directory) into a tensor and an affine.

    SimpleITK is tried first. If it raises ``RuntimeError``, NiBabel is
    tried as a fallback.

    Args:
        path: Path to an image file, or to a directory with a DICOM series.

    Returns:
        The ``(tensor, affine)`` tuple produced by whichever reader succeeded.

    Raises:
        RuntimeError: If neither SimpleITK nor NiBabel understands the file.
            The NiBabel error is chained as ``__cause__``.
    """
    try:
        result = _read_sitk(path)
    except RuntimeError:  # try with NiBabel
        try:
            result = _read_nibabel(path)
        except nib.loadsave.ImageFileError as error:
            message = (
                f'File "{path}" not understood.'
                ' Check supported formats at'
                ' https://simpleitk.readthedocs.io/en/master/IO.html#images'
                ' and https://nipy.org/nibabel/api.html#file-formats'
            )
            # Chain the NiBabel error so the original cause is not lost
            raise RuntimeError(message) from error
    return result
32
33
34
def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image with NiBabel as float32 data together with its affine.

    For 5D arrays, index 0 of the fourth axis is taken and the last axis
    (the components) is moved to the front. Other ranks are returned as
    loaded.
    """
    nifti = nib.load(str(path), mmap=False)
    voxels = nifti.get_fdata(dtype=np.float32)
    if voxels.ndim == 5:
        # NIfTI components live in the last axis; move them to channels
        voxels = voxels[..., 0, :].transpose(3, 0, 1, 2)
    tensor = torch.as_tensor(check_uint_to_int(voxels))
    return tensor, nifti.affine
44
45
46
def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    """Read an image with SimpleITK and return it as a tensor and an affine.

    A directory is treated as a DICOM series. Any other path is passed to
    ``sitk.ReadImage``.
    """
    path_is_directory = Path(path).is_dir()
    image = (
        _read_dicom(path)
        if path_is_directory  # assume DICOM
        else sitk.ReadImage(str(path))
    )
    array, affine = sitk_to_nib(image, keepdim=True)
    return torch.as_tensor(check_uint_to_int(array)), affine
55
56
57
def _read_dicom(directory: TypePath):
    """Read the DICOM series contained in a directory as a SimpleITK image.

    Raises:
        FileNotFoundError: If ``directory`` is not a directory, or if no
            DICOM file names are found in it.
    """
    folder = Path(directory)
    # Unreachable if called from _read_sitk, which checks is_dir() first
    if not folder.is_dir():
        raise FileNotFoundError(f'Directory "{folder}" not found')
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(str(folder))
    if not file_names:
        raise FileNotFoundError(
            f'The directory "{folder}"'
            ' does not seem to contain DICOM files'
        )
    series_reader.SetFileNames(file_names)
    return series_reader.Execute()
72
73
74
def read_shape(path: TypePath) -> Tuple[int, int, int, int]:
    """Read the shape of an image from its header, without reading voxels.

    Returns:
        ``(num_channels,) + spatial_size``. For 2D images, a singleton is
        appended to the spatial size, giving ``(C, W, H, 1)``.
    """
    # Reuse the shared helper instead of duplicating the reader setup
    reader = get_reader(path)  # calls ReadImageInformation()
    num_channels = reader.GetNumberOfComponents()
    spatial_shape = tuple(reader.GetSize())
    if reader.GetDimension() == 2:
        spatial_shape += (1,)
    return (num_channels,) + spatial_shape
83
84
85
def read_affine(path: TypePath) -> np.ndarray:
    """Return the 4x4 RAS affine of an image, read from its header only."""
    return get_ras_affine_from_sitk(get_reader(path))
89
90
91
def get_reader(path: TypePath, read: bool = True) -> sitk.ImageFileReader:
    """Create a SimpleITK ``ImageFileReader`` pointing at ``path``.

    Args:
        path: Image file to read.
        read: If ``True``, call ``ReadImageInformation()`` on the reader
            before returning it.
    """
    file_reader = sitk.ImageFileReader()
    file_reader.SetFileName(str(path))
    if not read:
        return file_reader
    file_reader.ReadImageInformation()
    return file_reader
97
98
99
def write_image(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = True,
        ) -> None:
    """Write a 4D tensor and its affine to ``path``.

    SimpleITK is tried first. If it raises ``RuntimeError``, NiBabel is used
    instead. ``squeeze`` is only forwarded to the NiBabel writer.
    """
    try:
        _write_sitk(tensor, affine, path)
    except RuntimeError:  # fall back to NiBabel
        _write_nibabel(tensor, affine, path, squeeze=squeeze)
110
111
112
def _write_nibabel(
        tensor: TypeData,
        affine: TypeData,
        path: TypePath,
        squeeze: bool = False,
        ) -> None:
    """Write a 4D channels-first tensor to a NIfTI-1 file with NiBabel.

    Expects a path with an extension that can be used by ``nibabel.save``
    to write a NIfTI-1 image, such as ``'.nii.gz'`` or ``'.img'``.

    Args:
        tensor: 4D tensor with shape ``(C, W, H, D)``.
        affine: 4x4 affine matrix.
        path: Output path. ``.nii`` (optionally followed by ``.gz``) is
            written as a single image; ``.hdr``/``.img`` as a NIfTI-1 pair.
        squeeze: If ``True``, remove all singleton dimensions before saving.

    Raises:
        nibabel.loadsave.ImageFileError: If the extension is not supported.
    """
    assert tensor.ndim == 4
    num_components = tensor.shape[0]

    # NIfTI components must be at the end, in a 5D array
    if num_components == 1:
        tensor = tensor[0]
    else:
        tensor = tensor[np.newaxis].permute(2, 3, 4, 0, 1)  # (W, H, D, 1, C)
    tensor = tensor.squeeze() if squeeze else tensor

    # Strip only a trailing compression suffix, not '.gz' anywhere in path
    path_string = str(path)
    if path_string.endswith('.gz'):
        path_string = path_string[:-len('.gz')]
    suffix = Path(path_string).suffix
    if '.nii' in suffix:
        img = nib.Nifti1Image(np.asarray(tensor), affine)
    elif '.hdr' in suffix or '.img' in suffix:
        img = nib.Nifti1Pair(np.asarray(tensor), affine)
    else:
        message = (
            f'Cannot write "{path}" with NiBabel:'
            f' unsupported extension "{suffix}"'
        )
        raise nib.loadsave.ImageFileError(message)
    if num_components > 1:
        img.header.set_intent('vector')
    img.header['qform_code'] = 1
    img.header['sform_code'] = 0
    nib.save(img, str(path))
143
144
145
def _write_sitk(
        tensor: torch.Tensor,
        affine: TypeData,
        path: TypePath,
        use_compression: bool = True,
        ) -> None:
    """Write a 4D tensor and its affine with SimpleITK.

    For PNG and JPEG outputs (extension matched case-insensitively), the
    data are cast to ``uint8`` first and a ``RuntimeWarning`` is emitted.

    Args:
        tensor: 4D tensor (or array) with channels first.
        affine: 4x4 RAS affine matrix.
        path: Output path.
        use_compression: Forwarded to ``sitk.WriteImage``.
    """
    assert tensor.ndim == 4
    path = Path(path)
    # Compare case-insensitively so that e.g. '.PNG' is also cast
    if path.suffix.lower() in ('.png', '.jpg', '.jpeg'):
        warnings.warn(
            f'Casting to uint 8 before saving to {path}',
            RuntimeWarning,
        )
        # np.asarray works for both CPU tensors and NumPy arrays
        tensor = np.asarray(tensor).astype(np.uint8)
    image = nib_to_sitk(tensor, affine)
    sitk.WriteImage(image, str(path), use_compression)
161
162
163
def read_matrix(path: TypePath):
    """Read an affine transform file and return it as a tensor.

    ``.tfm``/``.h5`` files are read as ITK transforms and ``.txt``/``.trsf``
    files as NiftyReg/blockmatching matrices.

    Raises:
        ValueError: If the file suffix is not one of the above.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        return _read_itk_matrix(path)
    if suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        return _read_niftyreg_matrix(path)
    raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
174
175
176
def write_matrix(matrix: torch.Tensor, path: TypePath):
    """Write an affine transform to a file.

    ``.tfm``/``.h5`` files are written as ITK transforms and ``.txt``/``.trsf``
    files as NiftyReg/blockmatching matrices.

    Raises:
        ValueError: If the file suffix is not one of the above. Previously,
            nothing was written and no error was raised.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix in ('.tfm', '.h5'):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in ('.txt', '.trsf'):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        raise ValueError(f'Unknown suffix for transform file: "{suffix}"')
184
185
186
def _to_itk_convention(matrix):
    """Convert a 4x4 affine from RAS to ITK's LPS convention.

    The x and y axes are negated on both sides of the matrix (F @ M @ F),
    and the result is then inverted.
    """
    flipped = np.dot(np.dot(FLIPXY, matrix), FLIPXY)
    return np.linalg.inv(flipped)
192
193
194
def _from_itk_convention(matrix):
    """Convert a 4x4 affine from ITK's LPS convention to RAS.

    The x and y axes are negated on both sides of the matrix (F @ M @ F),
    and the result is then inverted.
    """
    flipped = np.dot(FLIPXY, np.dot(matrix, FLIPXY))
    return np.linalg.inv(flipped)
200
201
202
def _read_itk_matrix(path):
    """Read a 3D affine transform in ITK's format and return it in RAS.

    The transform must have 12 parameters: a row-major 3x3 matrix ``A``
    followed by a translation ``t``. ITK applies ``A (x - c) + t + c``,
    where ``c`` is the center stored in the fixed parameters. The effective
    offset ``t + c - A c`` is therefore used as the homogeneous translation.

    Raises:
        ValueError: If the transform does not have 12 parameters.
    """
    transform = sitk.ReadTransform(str(path))
    parameters = transform.GetParameters()
    if len(parameters) != 12:
        message = (
            f'Expected 12 parameters for a 3D affine transform in "{path}",'
            f' but found {len(parameters)}'
        )
        raise ValueError(message)
    rotation_matrix = np.array(parameters[:9]).reshape(3, 3)
    translation = np.array(parameters[9:], dtype=float)
    # Fold the center of rotation into the offset. For files written by
    # _write_itk_matrix the center is presumably zero, so this is a no-op.
    fixed_parameters = transform.GetFixedParameters()
    if len(fixed_parameters) >= 3:
        center = np.array(fixed_parameters[:3], dtype=float)
        translation = translation + center - rotation_matrix.dot(center)
    matrix = np.hstack([rotation_matrix, translation.reshape(3, 1)])
    homogeneous_matrix_lps = np.vstack([matrix, [0, 0, 0, 1]])
    homogeneous_matrix_ras = _from_itk_convention(homogeneous_matrix_lps)
    return torch.as_tensor(homogeneous_matrix_ras)
214
215
216
def _write_itk_matrix(matrix, tfm_path):
    """Write a RAS affine matrix as an ITK transform file.

    The tfm file contains the matrix from floating to reference.
    """
    _matrix_to_itk_transform(matrix).WriteTransform(str(tfm_path))
220
221
222
def _matrix_to_itk_transform(matrix, dimensions=3):
    """Convert a homogeneous RAS affine into a SimpleITK ``AffineTransform``.

    Args:
        matrix: Homogeneous affine matrix in RAS convention.
        dimensions: Number of spatial dimensions to take from the matrix.

    NOTE(review): ``_to_itk_convention`` multiplies by the 4x4 FLIPXY and the
    translation column is hard-coded as index 3, so in practice only
    ``dimensions=3`` with a 4x4 matrix works.
    """
    itk_matrix = _to_itk_convention(matrix)
    linear_part = itk_matrix[:dimensions, :dimensions]
    translation_column = itk_matrix[:dimensions, 3]
    return sitk.AffineTransform(
        linear_part.ravel().tolist(),
        translation_column.tolist(),
    )
228
229
230
def _read_niftyreg_matrix(trsf_path):
231
    """Read a NiftyReg matrix and return it as a NumPy array"""
232
    matrix = np.loadtxt(trsf_path)
233
    matrix = np.linalg.inv(matrix)
234
    return torch.as_tensor(matrix)
235
236
237
def _write_niftyreg_matrix(matrix, txt_path):
238
    """Write an affine transform in NiftyReg's .txt format (ref -> flo)"""
239
    matrix = np.linalg.inv(matrix)
240
    np.savetxt(txt_path, matrix, fmt='%.8f')
241
242
243
def get_rotation_and_spacing_from_affine(
        affine: np.ndarray,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the 3x3 linear part of an affine into rotation and spacing.

    The spacing is the Euclidean norm of each column of the linear part.
    The rotation is the linear part with each column divided by its norm.
    """
    # From https://github.com/nipy/nibabel/blob/master/nibabel/orientations.py
    linear = affine[:3, :3]
    column_norms = np.sqrt(np.sum(linear * linear, axis=0))
    return linear / column_norms, column_norms
251
252
253
def nib_to_sitk(
        data: TypeData,
        affine: TypeData,
        force_3d: bool = False,
        force_4d: bool = False,
        ) -> sitk.Image:
    """Create a SimpleITK image from a tensor and a 4x4 affine matrix.

    Args:
        data: 4D array or tensor with shape ``(C, W, H, D)``. If ``D == 1``,
            a 2D image is created unless ``force_3d`` is ``True``.
        affine: 4x4 RAS affine matrix.
        force_3d: Keep a singleton last dimension instead of creating a 2D
            image.
        force_4d: Do not create a vector (multi-component) image; keep the
            first axis as an image dimension instead.

    Raises:
        ValueError: If ``data`` is not 4D.
    """
    if data.ndim != 4:
        shape = tuple(data.shape)
        raise ValueError(f'Input must be 4D, but has shape {shape}')
    # Possibilities
    # (1, w, h, 1)  monochannel 2D
    # (c, w, h, 1)  multichannel 2D
    # (1, w, h, d)  monochannel 3D
    # (c, w, h, d)  multichannel 3D
    array = np.asarray(data)
    affine = np.asarray(affine).astype(np.float64)

    is_multichannel = array.shape[0] > 1 and not force_4d
    is_2d = array.shape[3] == 1 and not force_3d
    if is_2d:
        array = array[..., 0]
    if not is_multichannel and not force_4d:
        array = array[0]
    # Reverse the axes, e.g. (C, W, H, D) -> (D, H, W, C) or
    # (W, H, D) -> (D, H, W): GetImageFromArray expects NumPy (z, y, x) order
    array = array.transpose()
    image = sitk.GetImageFromArray(array, isVector=is_multichannel)

    origin, spacing, direction = get_sitk_metadata_from_ras_affine(
        affine,
        is_2d=is_2d,
    )
    image.SetOrigin(origin)  # TODO: should a 4th value be added if force_4d?
    image.SetSpacing(spacing)
    image.SetDirection(direction)

    # data.ndim == 4 is guaranteed by the check at the top of the function.
    # NOTE(review): with force_4d and C > 1 this component check looks like
    # it cannot hold; confirm whether force_4d is used with multiple channels
    assert image.GetNumberOfComponentsPerPixel() == data.shape[0]
    num_spatial_dims = 2 if is_2d else 3
    assert image.GetSize() == data.shape[1: 1 + num_spatial_dims]

    return image
294
295
296
def sitk_to_nib(
        image: sitk.Image,
        keepdim: bool = False,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into channels-first data and a RAS affine.

    Args:
        image: SimpleITK image.
        keepdim: If ``False``, the data are passed through ``ensure_4d``,
            which returns a torch tensor. If ``True``, the NumPy array is
            returned as is.

    Returns:
        ``(data, affine)``. ``data`` has a leading channels axis and, for 2D
        images, a trailing singleton axis.
    """
    num_components = image.GetNumberOfComponentsPerPixel()
    spatial_dims = image.GetDimension()
    # Transposing reverses the NumPy axes, putting components first
    array = check_uint_to_int(sitk.GetArrayFromImage(image).transpose())
    if num_components == 1:
        array = array[np.newaxis]  # add channels dimension
    if spatial_dims == 2:
        array = array[..., np.newaxis]  # add singleton last dimension
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=spatial_dims)
    assert array.shape[0] == num_components
    assert array.shape[1: 1 + spatial_dims] == image.GetSize()
    return array, get_ras_affine_from_sitk(image)
314
315
316
def get_ras_affine_from_sitk(
        sitk_object: Union[sitk.Image, sitk.ImageFileReader]
        ) -> np.ndarray:
    """Build a 4x4 RAS affine from the spacing, direction and origin of a
    SimpleITK image or file reader.

    A 2D object (4 direction values) is embedded in 3D with unit spacing
    and zero origin along the third axis.

    Raises:
        RuntimeError: If the direction does not have 4 or 9 values.
    """
    spacing = np.array(sitk_object.GetSpacing())
    direction_lps = np.array(sitk_object.GetDirection())
    origin_lps = np.array(sitk_object.GetOrigin())
    num_direction_values = len(direction_lps)
    if num_direction_values == 4:  # ignore first dimension if 2D (1, W, H, 1)
        rotation = np.eye(3)
        rotation[:2, :2] = direction_lps.reshape(2, 2)
        spacing = np.array([*spacing, 1])
        origin_lps = np.array([*origin_lps, 0])
    elif num_direction_values == 9:
        rotation = direction_lps.reshape(3, 3)
    else:
        raise RuntimeError(f'Direction not understood: {direction_lps}')
    lps_to_ras = np.diag((-1, -1, 1))  # negate x and y
    affine = np.eye(4)
    # Scale each column of the rotation by the spacing of that axis
    affine[:3, :3] = np.dot(lps_to_ras, rotation) * spacing
    affine[:3, 3] = np.dot(lps_to_ras, origin_lps)
    return affine
340
341
342
def get_sitk_metadata_from_ras_affine(
        affine: np.ndarray,
        is_2d: bool = False,
        ) -> Tuple[TypeTripletFloat, TypeTripletFloat, Tuple[float, ...]]:
    """Compute SimpleITK origin, spacing and direction from a RAS affine.

    The x and y axes are negated to go from RAS to LPS. If ``is_2d``, only
    the upper-left 2x2 block of the direction is kept. Origin and spacing
    always keep three values.

    Returns:
        ``(origin, spacing, direction)``. Origin and spacing are NumPy arrays
        of length 3; direction is a flat tuple of 9 (or 4 if ``is_2d``)
        values.
    """
    ras_to_lps = np.diag((-1, -1, 1))
    rotation, spacing = get_rotation_and_spacing_from_affine(affine)
    origin = np.dot(ras_to_lps, affine[:3, 3])
    direction_matrix = np.dot(ras_to_lps, rotation)
    if is_2d:  # keep only the block for the first two spatial axes
        direction_matrix = direction_matrix[:2, :2]
    return origin, spacing, tuple(direction_matrix.flatten())
354
355
356
def ensure_4d(tensor: TypeData, num_spatial_dims=None) -> TypeData:
    """Convert an array or tensor to a 4D ``(C, W, H, D)`` torch tensor.

    Handled inputs:
        - 4D: returned unchanged (as a tensor).
        - 5D ``(W, H, D, 1, C)``: reordered to ``(C, W, H, D)``.
        - 2D ``(W, H)``: becomes ``(1, W, H, 1)``.
        - 3D: ``(C, W, H) -> (C, W, H, 1)`` if ``num_spatial_dims == 2``;
          ``(W, H, D) -> (1, W, H, D)`` if ``num_spatial_dims == 3``.
          Otherwise the layout is guessed: a size of 3 in the first or last
          axis is treated as RGB channels.

    Raises:
        ValueError: For 5D inputs whose ``shape[-2]`` is not 1, and for any
            other number of dimensions.
    """
    # I wish named tensors were properly supported in PyTorch
    tensor = torch.as_tensor(tensor)
    num_dimensions = tensor.ndim
    if num_dimensions == 4:
        pass
    elif num_dimensions == 5:  # hope (W, H, D, 1, C)
        if tensor.shape[-2] != 1:
            raise ValueError('5D is not supported for shape[-2] > 1')
        tensor = tensor[..., 0, :].permute(3, 0, 1, 2)
    elif num_dimensions == 2:  # assume 2D monochannel (W, H)
        tensor = tensor[None, ..., None]  # (1, W, H, 1)
    elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
        if num_spatial_dims == 2:
            tensor = tensor[..., None]  # (C, W, H, 1)
        elif num_spatial_dims == 3:
            tensor = tensor[None]  # (W, H, D) -> (1, W, H, D)
        else:  # try to guess
            first, _, last = tensor.shape
            if 3 in (first, last):  # maybe RGB
                if last == 3:
                    tensor = tensor.permute(2, 0, 1)  # (W, H, 3) -> (3, W, H)
                tensor = tensor[..., None]  # (3, W, H, 1)
            else:
                tensor = tensor[None]  # (W, H, D) -> (1, W, H, D)
    else:
        message = (
            f'{num_dimensions}D images not supported yet. Please create an'
            f' issue in {REPO_URL} if you would like support for them'
        )
        raise ValueError(message)
    assert tensor.ndim == 4
    return tensor
392
393
394
def check_uint_to_int(array):
    """Cast unsigned arrays that PyTorch cannot ingest to wider signed types.

    ``uint16`` becomes ``int32`` and ``uint32`` becomes ``int64``, so every
    value is preserved. Arrays of any other dtype are returned unchanged
    (the same object).
    """
    # This is because PyTorch won't take uint16 nor uint32
    for unsigned_type, signed_type in ((np.uint16, np.int32),
                                       (np.uint32, np.int64)):
        if array.dtype == unsigned_type:
            return array.astype(signed_type)
    return array
401