1
|
|
|
import warnings |
2
|
|
|
from pathlib import Path |
3
|
|
|
from collections.abc import Iterable |
4
|
|
|
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List, Callable |
5
|
|
|
|
6
|
|
|
import torch |
7
|
|
|
import humanize |
8
|
|
|
import numpy as np |
9
|
|
|
import nibabel as nib |
10
|
|
|
import SimpleITK as sitk |
11
|
|
|
from deprecated import deprecated |
12
|
|
|
|
13
|
|
|
from ..utils import get_stem |
14
|
|
|
from ..typing import ( |
15
|
|
|
TypeData, |
16
|
|
|
TypePath, |
17
|
|
|
TypeTripletInt, |
18
|
|
|
TypeTripletFloat, |
19
|
|
|
TypeDirection3D, |
20
|
|
|
) |
21
|
|
|
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL |
22
|
|
|
from .io import ( |
23
|
|
|
ensure_4d, |
24
|
|
|
read_image, |
25
|
|
|
write_image, |
26
|
|
|
nib_to_sitk, |
27
|
|
|
sitk_to_nib, |
28
|
|
|
check_uint_to_int, |
29
|
|
|
get_rotation_and_spacing_from_affine, |
30
|
|
|
get_sitk_metadata_from_ras_affine, |
31
|
|
|
read_shape, |
32
|
|
|
read_affine, |
33
|
|
|
) |
34
|
|
|
|
35
|
|
|
|
36
|
|
|
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM |
37
|
|
|
TypeBound = Tuple[float, float] |
38
|
|
|
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound] |
39
|
|
|
|
40
|
|
|
deprecation_message = ( |
41
|
|
|
'Setting the image data with the property setter is deprecated. Use the' |
42
|
|
|
' set_data() method instead' |
43
|
|
|
) |
44
|
|
|
|
45
|
|
|
|
46
|
|
|
class Image(dict): |
47
|
|
|
r"""TorchIO image. |
48
|
|
|
|
49
|
|
|
For information about medical image orientation, check out `NiBabel docs`_, |
50
|
|
|
the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or |
51
|
|
|
`SimpleITK docs`_. |
52
|
|
|
|
53
|
|
|
Args: |
54
|
|
|
path: Path to a file or sequence of paths to files that can be read by |
55
|
|
|
:mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing |
56
|
|
|
DICOM files. If :attr:`tensor` is given, the data in |
57
|
|
|
:attr:`path` will not be read. |
58
|
|
|
If a sequence of paths is given, data |
59
|
|
|
will be concatenated on the channel dimension so spatial |
60
|
|
|
dimensions must match. |
61
|
|
|
type: Type of image, such as :attr:`torchio.INTENSITY` or |
62
|
|
|
:attr:`torchio.LABEL`. This will be used by the transforms to |
63
|
|
|
decide whether to apply an operation, or which interpolation to use |
64
|
|
|
when resampling. For example, `preprocessing`_ and `augmentation`_ |
65
|
|
|
intensity transforms will only be applied to images with type |
66
|
|
|
:attr:`torchio.INTENSITY`. Spatial transforms will be applied to |
67
|
|
|
all types, and nearest neighbor interpolation is always used to |
68
|
|
|
resample images with type :attr:`torchio.LABEL`. |
69
|
|
|
The type :attr:`torchio.SAMPLING_MAP` may be used with instances of |
70
|
|
|
:class:`~torchio.data.sampler.weighted.WeightedSampler`. |
71
|
|
|
tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D |
72
|
|
|
:class:`torch.Tensor` or NumPy array with dimensions |
73
|
|
|
:math:`(C, W, H, D)`. |
74
|
|
|
affine: :math:`4 \times 4` matrix to convert voxel coordinates to world |
75
|
|
|
coordinates. If ``None``, an identity matrix will be used. See the |
76
|
|
|
`NiBabel docs on coordinates`_ for more information. |
77
|
|
|
check_nans: If ``True``, issues a warning if NaNs are found |
78
|
|
|
in the image. If ``False``, images will not be checked for the |
79
|
|
|
presence of NaNs. |
80
|
|
|
channels_last: If ``True``, the read tensor will be permuted so the |
81
|
|
|
last dimension becomes the first. This is useful, e.g., when |
82
|
|
|
NIfTI images have been saved with the channels dimension being the |
83
|
|
|
fourth instead of the fifth. |
84
|
|
|
reader: Callable object that takes a path and returns a 4D tensor and a |
85
|
|
|
2D, :math:`4 \times 4` affine matrix. This can be used if your data |
86
|
|
|
is saved in a custom format, such as ``.npy`` (see example below). |
87
|
|
|
If the affine matrix is ``None``, an identity matrix will be used. |
88
|
|
|
**kwargs: Items that will be added to the image dictionary, e.g. |
89
|
|
|
acquisition parameters. |
90
|
|
|
|
91
|
|
|
TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk |
92
|
|
|
when needed. |
93
|
|
|
|
94
|
|
|
Example: |
95
|
|
|
>>> import torchio as tio |
96
|
|
|
>>> import numpy as np |
97
|
|
|
>>> image = tio.ScalarImage('t1.nii.gz') # subclass of Image |
98
|
|
|
>>> image # not loaded yet |
99
|
|
|
ScalarImage(path: t1.nii.gz; type: intensity) |
100
|
|
|
>>> times_two = 2 * image.data # data is loaded and cached here |
101
|
|
|
>>> image |
102
|
|
|
ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity) |
103
|
|
|
>>> image.save('doubled_image.nii.gz') |
104
|
|
|
>>> numpy_reader = lambda path: np.load(path), np.eye(4) |
105
|
|
|
>>> image = tio.ScalarImage('t1.npy', reader=numpy_reader) |
106
|
|
|
|
107
|
|
|
.. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading |
108
|
|
|
.. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity |
109
|
|
|
.. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity |
110
|
|
|
.. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html |
111
|
|
|
.. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces |
112
|
|
|
.. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems |
113
|
|
|
.. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained |
114
|
|
|
.. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html |
115
|
|
|
.. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm |
116
|
|
|
""" |
117
|
|
|
def __init__( |
118
|
|
|
self, |
119
|
|
|
path: Union[TypePath, Sequence[TypePath], None] = None, |
120
|
|
|
type: str = None, |
121
|
|
|
tensor: Optional[TypeData] = None, |
122
|
|
|
affine: Optional[TypeData] = None, |
123
|
|
|
check_nans: bool = False, # removed by ITK by default |
124
|
|
|
channels_last: bool = False, |
125
|
|
|
reader: Callable = read_image, |
126
|
|
|
**kwargs: Dict[str, Any], |
127
|
|
|
): |
128
|
|
|
self.check_nans = check_nans |
129
|
|
|
self.channels_last = channels_last |
130
|
|
|
self.reader = reader |
131
|
|
|
|
132
|
|
|
if type is None: |
133
|
|
|
warnings.warn( |
134
|
|
|
'Not specifying the image type is deprecated and will be' |
135
|
|
|
' mandatory in the future. You can probably use tio.ScalarImage' |
136
|
|
|
' or tio.LabelMap instead', |
137
|
|
|
) |
138
|
|
|
type = INTENSITY |
139
|
|
|
|
140
|
|
|
if path is None and tensor is None: |
141
|
|
|
raise ValueError('A value for path or tensor must be given') |
142
|
|
|
self._loaded = False |
143
|
|
|
|
144
|
|
|
tensor = self._parse_tensor(tensor) |
145
|
|
|
affine = self._parse_affine(affine) |
146
|
|
|
if tensor is not None: |
147
|
|
|
self.set_data(tensor) |
148
|
|
|
self.affine = affine |
149
|
|
|
self._loaded = True |
150
|
|
|
for key in PROTECTED_KEYS: |
151
|
|
|
if key in kwargs: |
152
|
|
|
message = f'Key "{key}" is reserved. Use a different one' |
153
|
|
|
raise ValueError(message) |
154
|
|
|
|
155
|
|
|
super().__init__(**kwargs) |
156
|
|
|
self.path = self._parse_path(path) |
157
|
|
|
|
158
|
|
|
self[PATH] = '' if self.path is None else str(self.path) |
159
|
|
|
self[STEM] = '' if self.path is None else get_stem(self.path) |
160
|
|
|
self[TYPE] = type |
161
|
|
|
|
162
|
|
|
def __repr__(self): |
163
|
|
|
properties = [] |
164
|
|
|
properties.extend([ |
165
|
|
|
f'shape: {self.shape}', |
166
|
|
|
f'spacing: {self.get_spacing_string()}', |
167
|
|
|
f'orientation: {"".join(self.orientation)}+', |
168
|
|
|
]) |
169
|
|
|
if self._loaded: |
170
|
|
|
properties.append(f'dtype: {self.data.type()}') |
171
|
|
|
properties.append(f'memory: {humanize.naturalsize(self.memory, binary=True)}') |
172
|
|
|
else: |
173
|
|
|
properties.append(f'path: "{self.path}"') |
174
|
|
|
|
175
|
|
|
properties = '; '.join(properties) |
176
|
|
|
string = f'{self.__class__.__name__}({properties})' |
177
|
|
|
return string |
178
|
|
|
|
179
|
|
|
def __getitem__(self, item): |
180
|
|
|
if item in (DATA, AFFINE): |
181
|
|
|
if item not in self: |
182
|
|
|
self.load() |
183
|
|
|
return super().__getitem__(item) |
184
|
|
|
|
185
|
|
|
def __array__(self): |
186
|
|
|
return self.data.numpy() |
187
|
|
|
|
188
|
|
|
def __copy__(self): |
189
|
|
|
kwargs = dict( |
190
|
|
|
tensor=self.data, |
191
|
|
|
affine=self.affine, |
192
|
|
|
type=self.type, |
193
|
|
|
path=self.path, |
194
|
|
|
) |
195
|
|
|
for key, value in self.items(): |
196
|
|
|
if key in PROTECTED_KEYS: continue |
197
|
|
|
kwargs[key] = value # should I copy? deepcopy? |
198
|
|
|
return self.__class__(**kwargs) |
199
|
|
|
|
200
|
|
|
@property |
201
|
|
|
def data(self) -> torch.Tensor: |
202
|
|
|
"""Tensor data. Same as :class:`Image.tensor`.""" |
203
|
|
|
return self[DATA] |
204
|
|
|
|
205
|
|
|
@data.setter # type: ignore |
206
|
|
|
@deprecated(version='0.18.16', reason=deprecation_message) |
207
|
|
|
def data(self, tensor: TypeData): |
208
|
|
|
self.set_data(tensor) |
209
|
|
|
|
210
|
|
|
def set_data(self, tensor: TypeData): |
211
|
|
|
"""Store a 4D tensor in the :attr:`data` key and attribute. |
212
|
|
|
|
213
|
|
|
Args: |
214
|
|
|
tensor: 4D tensor with dimensions :math:`(C, W, H, D)`. |
215
|
|
|
""" |
216
|
|
|
self[DATA] = self._parse_tensor(tensor, none_ok=False) |
217
|
|
|
|
218
|
|
|
@property |
219
|
|
|
def tensor(self) -> torch.Tensor: |
220
|
|
|
"""Tensor data. Same as :class:`Image.data`.""" |
221
|
|
|
return self.data |
222
|
|
|
|
223
|
|
|
@property |
224
|
|
|
def affine(self) -> np.ndarray: |
225
|
|
|
"""Affine matrix to transform voxel indices into world coordinates.""" |
226
|
|
|
# If path is a dir (probably DICOM), just load the data |
227
|
|
|
# Same if it's a list of paths (used to create a 4D image) |
228
|
|
|
if self._loaded or (isinstance(self.path, Path) and self.path.is_dir()): |
229
|
|
|
affine = self[AFFINE] |
230
|
|
|
else: |
231
|
|
|
affine = read_affine(self.path) |
232
|
|
|
return affine |
233
|
|
|
|
234
|
|
|
@affine.setter |
235
|
|
|
def affine(self, matrix): |
236
|
|
|
self[AFFINE] = self._parse_affine(matrix) |
237
|
|
|
|
238
|
|
|
@property |
239
|
|
|
def type(self) -> str: |
240
|
|
|
return self[TYPE] |
241
|
|
|
|
242
|
|
|
@property |
243
|
|
|
def shape(self) -> Tuple[int, int, int, int]: |
244
|
|
|
"""Tensor shape as :math:`(C, W, H, D)`.""" |
245
|
|
|
custom_reader = self.reader is not read_image |
246
|
|
|
multipath = not isinstance(self.path, (str, Path)) |
247
|
|
|
if self._loaded or custom_reader or multipath or self.path.is_dir(): |
248
|
|
|
shape = tuple(self.data.shape) |
249
|
|
|
else: |
250
|
|
|
shape = read_shape(self.path) |
251
|
|
|
return shape |
252
|
|
|
|
253
|
|
|
@property |
254
|
|
|
def spatial_shape(self) -> TypeTripletInt: |
255
|
|
|
"""Tensor spatial shape as :math:`(W, H, D)`.""" |
256
|
|
|
return self.shape[1:] |
257
|
|
|
|
258
|
|
|
def check_is_2d(self) -> None: |
259
|
|
|
if not self.is_2d(): |
260
|
|
|
message = f'Image is not 2D. Spatial shape: {self.spatial_shape}' |
261
|
|
|
raise RuntimeError(message) |
262
|
|
|
|
263
|
|
|
@property |
264
|
|
|
def height(self) -> int: |
265
|
|
|
"""Image height, if 2D.""" |
266
|
|
|
self.check_is_2d() |
267
|
|
|
return self.spatial_shape[1] |
268
|
|
|
|
269
|
|
|
@property |
270
|
|
|
def width(self) -> int: |
271
|
|
|
"""Image width, if 2D.""" |
272
|
|
|
self.check_is_2d() |
273
|
|
|
return self.spatial_shape[0] |
274
|
|
|
|
275
|
|
|
@property |
276
|
|
|
def orientation(self) -> Tuple[str, str, str]: |
277
|
|
|
"""Orientation codes.""" |
278
|
|
|
return nib.aff2axcodes(self.affine) |
279
|
|
|
|
280
|
|
|
@property |
281
|
|
|
def direction(self) -> TypeDirection3D: |
282
|
|
|
_, _, direction = get_sitk_metadata_from_ras_affine( |
283
|
|
|
self.affine, lps=False) |
284
|
|
|
return direction |
285
|
|
|
|
286
|
|
|
@property |
287
|
|
|
def spacing(self) -> Tuple[float, float, float]: |
288
|
|
|
"""Voxel spacing in mm.""" |
289
|
|
|
_, spacing = get_rotation_and_spacing_from_affine(self.affine) |
290
|
|
|
return tuple(spacing) |
291
|
|
|
|
292
|
|
|
@property |
293
|
|
|
def origin(self) -> Tuple[float, float, float]: |
294
|
|
|
"""Center of first voxel in array, in mm.""" |
295
|
|
|
return tuple(self.affine[:3, 3]) |
296
|
|
|
|
297
|
|
|
@property |
298
|
|
|
def itemsize(self): |
299
|
|
|
"""Element size of the data type.""" |
300
|
|
|
return self.data.element_size() |
301
|
|
|
|
302
|
|
|
@property |
303
|
|
|
def memory(self) -> float: |
304
|
|
|
"""Number of Bytes that the tensor takes in the RAM.""" |
305
|
|
|
return np.prod(self.shape) * self.itemsize |
306
|
|
|
|
307
|
|
|
@property |
308
|
|
|
def bounds(self) -> np.ndarray: |
309
|
|
|
"""Position of centers of voxels in smallest and largest coordinates.""" |
310
|
|
|
ini = 0, 0, 0 |
311
|
|
|
fin = np.array(self.spatial_shape) - 1 |
312
|
|
|
point_ini = nib.affines.apply_affine(self.affine, ini) |
313
|
|
|
point_fin = nib.affines.apply_affine(self.affine, fin) |
314
|
|
|
return np.array((point_ini, point_fin)) |
315
|
|
|
|
316
|
|
|
@property |
317
|
|
|
def num_channels(self) -> int: |
318
|
|
|
"""Get the number of channels in the associated 4D tensor.""" |
319
|
|
|
return len(self.data) |
320
|
|
|
|
321
|
|
|
def axis_name_to_index(self, axis: str) -> int: |
322
|
|
|
"""Convert an axis name to an axis index. |
323
|
|
|
|
324
|
|
|
Args: |
325
|
|
|
axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``, |
326
|
|
|
``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case |
327
|
|
|
versions and first letters are also valid, as only the first |
328
|
|
|
letter will be used. |
329
|
|
|
|
330
|
|
|
.. note:: If you are working with animals, you should probably use |
331
|
|
|
``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'`` |
332
|
|
|
for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``, |
333
|
|
|
respectively. |
334
|
|
|
|
335
|
|
|
.. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``, |
336
|
|
|
``'Left'`` and ``'Right'``. |
337
|
|
|
""" |
338
|
|
|
# Top and bottom are used for the vertical 2D axis as the use of |
339
|
|
|
# Height vs Horizontal might be ambiguous |
340
|
|
|
|
341
|
|
|
if not isinstance(axis, str): |
342
|
|
|
raise ValueError('Axis must be a string') |
343
|
|
|
axis = axis[0].upper() |
344
|
|
|
|
345
|
|
|
# Generally, TorchIO tensors are (C, W, H, D) |
346
|
|
|
if axis in 'TB': # Top, Bottom |
347
|
|
|
return -2 |
348
|
|
|
else: |
349
|
|
|
try: |
350
|
|
|
index = self.orientation.index(axis) |
351
|
|
|
except ValueError: |
352
|
|
|
index = self.orientation.index(self.flip_axis(axis)) |
353
|
|
|
# Return negative indices so that it does not matter whether we |
354
|
|
|
# refer to spatial dimensions or not |
355
|
|
|
index = -3 + index |
356
|
|
|
return index |
357
|
|
|
|
358
|
|
|
# flake8: noqa: E701 |
359
|
|
|
@staticmethod |
360
|
|
|
def flip_axis(axis: str) -> str: |
361
|
|
|
if axis == 'R': flipped_axis = 'L' |
362
|
|
|
elif axis == 'L': flipped_axis = 'R' |
363
|
|
|
elif axis == 'A': flipped_axis = 'P' |
364
|
|
|
elif axis == 'P': flipped_axis = 'A' |
365
|
|
|
elif axis == 'I': flipped_axis = 'S' |
366
|
|
|
elif axis == 'S': flipped_axis = 'I' |
367
|
|
|
elif axis == 'T': flipped_axis = 'B' # top / bottom |
368
|
|
|
elif axis == 'B': flipped_axis = 'T' |
369
|
|
|
else: |
370
|
|
|
values = ', '.join('LRPAISTB') |
371
|
|
|
message = f'Axis not understood. Please use one of: {values}' |
372
|
|
|
raise ValueError(message) |
373
|
|
|
return flipped_axis |
374
|
|
|
|
375
|
|
|
def get_spacing_string(self) -> str: |
376
|
|
|
strings = [f'{n:.2f}' for n in self.spacing] |
377
|
|
|
string = f'({", ".join(strings)})' |
378
|
|
|
return string |
379
|
|
|
|
380
|
|
|
def get_bounds(self) -> TypeBounds: |
381
|
|
|
"""Get minimum and maximum world coordinates occupied by the image.""" |
382
|
|
|
first_index = 3 * (-0.5,) |
383
|
|
|
last_index = np.array(self.spatial_shape) - 0.5 |
384
|
|
|
first_point = nib.affines.apply_affine(self.affine, first_index) |
385
|
|
|
last_point = nib.affines.apply_affine(self.affine, last_index) |
386
|
|
|
array = np.array((first_point, last_point)) |
387
|
|
|
bounds_x, bounds_y, bounds_z = array.T.tolist() |
388
|
|
|
return bounds_x, bounds_y, bounds_z |
389
|
|
|
|
390
|
|
|
@staticmethod |
391
|
|
|
def _parse_single_path( |
392
|
|
|
path: TypePath |
393
|
|
|
) -> Path: |
394
|
|
|
try: |
395
|
|
|
path = Path(path).expanduser() |
396
|
|
|
except TypeError: |
397
|
|
|
message = ( |
398
|
|
|
f'Expected type str or Path but found {path} with type' |
399
|
|
|
f' {type(path)} instead' |
400
|
|
|
) |
401
|
|
|
raise TypeError(message) |
402
|
|
|
except RuntimeError: |
403
|
|
|
message = ( |
404
|
|
|
f'Conversion to path not possible for variable: {path}' |
405
|
|
|
) |
406
|
|
|
raise RuntimeError(message) |
407
|
|
|
|
408
|
|
|
if not (path.is_file() or path.is_dir()): # might be a dir with DICOM |
409
|
|
|
raise FileNotFoundError(f'File not found: "{path}"') |
410
|
|
|
return path |
411
|
|
|
|
412
|
|
|
def _parse_path( |
413
|
|
|
self, |
414
|
|
|
path: Union[TypePath, Sequence[TypePath], None] |
415
|
|
|
) -> Optional[Union[Path, List[Path]]]: |
416
|
|
|
if path is None: |
417
|
|
|
return None |
418
|
|
|
if isinstance(path, Iterable) and not isinstance(path, str): |
419
|
|
|
return [self._parse_single_path(p) for p in path] |
420
|
|
|
else: |
421
|
|
|
return self._parse_single_path(path) |
422
|
|
|
|
423
|
|
|
def _parse_tensor( |
424
|
|
|
self, |
425
|
|
|
tensor: Optional[TypeData], |
426
|
|
|
none_ok: bool = True, |
427
|
|
|
) -> Optional[torch.Tensor]: |
428
|
|
|
if tensor is None: |
429
|
|
|
if none_ok: |
430
|
|
|
return None |
431
|
|
|
else: |
432
|
|
|
raise RuntimeError('Input tensor cannot be None') |
433
|
|
|
if isinstance(tensor, np.ndarray): |
434
|
|
|
tensor = check_uint_to_int(tensor) |
435
|
|
|
tensor = torch.as_tensor(tensor) |
436
|
|
|
elif not isinstance(tensor, torch.Tensor): |
437
|
|
|
message = ( |
438
|
|
|
'Input tensor must be a PyTorch tensor or NumPy array,' |
439
|
|
|
f' but type "{type(tensor)}" was found' |
440
|
|
|
) |
441
|
|
|
raise TypeError(message) |
442
|
|
|
ndim = tensor.ndim |
443
|
|
|
if ndim != 4: |
444
|
|
|
raise ValueError(f'Input tensor must be 4D, but it is {ndim}D') |
445
|
|
|
if tensor.dtype == torch.bool: |
446
|
|
|
tensor = tensor.to(torch.uint8) |
447
|
|
|
if self.check_nans and torch.isnan(tensor).any(): |
448
|
|
|
warnings.warn(f'NaNs found in tensor', RuntimeWarning) |
449
|
|
|
return tensor |
450
|
|
|
|
451
|
|
|
@staticmethod |
452
|
|
|
def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData: |
453
|
|
|
return ensure_4d(tensor) |
454
|
|
|
|
455
|
|
|
@staticmethod |
456
|
|
|
def _parse_affine(affine: Optional[TypeData]) -> np.ndarray: |
457
|
|
|
if affine is None: |
458
|
|
|
return np.eye(4) |
459
|
|
|
if isinstance(affine, torch.Tensor): |
460
|
|
|
affine = affine.numpy() |
461
|
|
|
if not isinstance(affine, np.ndarray): |
462
|
|
|
raise TypeError(f'Affine must be a NumPy array, not {type(affine)}') |
463
|
|
|
if affine.shape != (4, 4): |
464
|
|
|
raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}') |
465
|
|
|
return affine.astype(np.float64) |
466
|
|
|
|
467
|
|
|
def load(self) -> None: |
468
|
|
|
r"""Load the image from disk. |
469
|
|
|
|
470
|
|
|
Returns: |
471
|
|
|
Tuple containing a 4D tensor of size :math:`(C, W, H, D)` and a 2D |
472
|
|
|
:math:`4 \times 4` affine matrix to convert voxel indices to world |
473
|
|
|
coordinates. |
474
|
|
|
""" |
475
|
|
|
if self._loaded: |
476
|
|
|
return |
477
|
|
|
paths = self.path if isinstance(self.path, list) else [self.path] |
478
|
|
|
tensor, affine = self.read_and_check(paths[0]) |
479
|
|
|
tensors = [tensor] |
480
|
|
|
for path in paths[1:]: |
481
|
|
|
new_tensor, new_affine = self.read_and_check(path) |
482
|
|
|
if not np.array_equal(affine, new_affine): |
483
|
|
|
message = ( |
484
|
|
|
'Files have different affine matrices.' |
485
|
|
|
f'\nMatrix of {paths[0]}:' |
486
|
|
|
f'\n{affine}' |
487
|
|
|
f'\nMatrix of {path}:' |
488
|
|
|
f'\n{new_affine}' |
489
|
|
|
) |
490
|
|
|
warnings.warn(message, RuntimeWarning) |
491
|
|
|
if not tensor.shape[1:] == new_tensor.shape[1:]: |
492
|
|
|
message = ( |
493
|
|
|
f'Files shape do not match, found {tensor.shape}' |
494
|
|
|
f'and {new_tensor.shape}' |
495
|
|
|
) |
496
|
|
|
RuntimeError(message) |
497
|
|
|
tensors.append(new_tensor) |
498
|
|
|
tensor = torch.cat(tensors) |
499
|
|
|
self.set_data(tensor) |
500
|
|
|
self.affine = affine |
501
|
|
|
self._loaded = True |
502
|
|
|
|
503
|
|
|
def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]: |
504
|
|
|
tensor, affine = self.reader(path) |
505
|
|
|
tensor = self._parse_tensor_shape(tensor) |
506
|
|
|
tensor = self._parse_tensor(tensor) |
507
|
|
|
affine = self._parse_affine(affine) |
508
|
|
|
if self.channels_last: |
509
|
|
|
tensor = tensor.permute(3, 0, 1, 2) |
510
|
|
|
if self.check_nans and torch.isnan(tensor).any(): |
511
|
|
|
warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning) |
512
|
|
|
return tensor, affine |
513
|
|
|
|
514
|
|
|
def save(self, path: TypePath, squeeze: Optional[bool] = None) -> None: |
515
|
|
|
"""Save image to disk. |
516
|
|
|
|
517
|
|
|
Args: |
518
|
|
|
path: String or instance of :class:`pathlib.Path`. |
519
|
|
|
squeeze: Whether to remove singleton dimensions before saving. |
520
|
|
|
If ``None``, the array will be squeezed if the output format is |
521
|
|
|
JP(E)G, PNG, BMP or TIF(F). |
522
|
|
|
""" |
523
|
|
|
write_image( |
524
|
|
|
self.data, |
525
|
|
|
self.affine, |
526
|
|
|
path, |
527
|
|
|
squeeze=squeeze, |
528
|
|
|
) |
529
|
|
|
|
530
|
|
|
def is_2d(self) -> bool: |
531
|
|
|
return self.shape[-1] == 1 |
532
|
|
|
|
533
|
|
|
def numpy(self) -> np.ndarray: |
534
|
|
|
"""Get a NumPy array containing the image data.""" |
535
|
|
|
return np.asarray(self) |
536
|
|
|
|
537
|
|
|
def as_sitk(self, **kwargs) -> sitk.Image: |
538
|
|
|
"""Get the image as an instance of :class:`sitk.Image`.""" |
539
|
|
|
return nib_to_sitk(self.data, self.affine, **kwargs) |
540
|
|
|
|
541
|
|
|
@classmethod |
542
|
|
|
def from_sitk(cls, sitk_image): |
543
|
|
|
"""Instantiate a new TorchIO image from a :class:`sitk.Image`. |
544
|
|
|
|
545
|
|
|
Example: |
546
|
|
|
>>> import torchio as tio |
547
|
|
|
>>> import SimpleITK as sitk |
548
|
|
|
>>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16) |
549
|
|
|
>>> tio.LabelMap.from_sitk(sitk_image) |
550
|
|
|
LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 93.8 KiB; dtype: torch.IntTensor) |
551
|
|
|
>>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3) |
552
|
|
|
>>> tio.ScalarImage.from_sitk(sitk_image) |
553
|
|
|
ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 588.0 KiB; dtype: torch.FloatTensor) |
554
|
|
|
""" |
555
|
|
|
tensor, affine = sitk_to_nib(sitk_image) |
556
|
|
|
return cls(tensor=tensor, affine=affine) |
557
|
|
|
|
558
|
|
|
def as_pil(self, transpose=True): |
559
|
|
|
"""Get the image as an instance of :class:`PIL.Image`. |
560
|
|
|
|
561
|
|
|
.. note:: Values will be clamped to 0-255 and cast to uint8. |
562
|
|
|
.. note:: To use this method, `Pillow` needs to be installed: |
563
|
|
|
`pip install Pillow`. |
564
|
|
|
""" |
565
|
|
|
try: |
566
|
|
|
from PIL import Image as ImagePIL |
567
|
|
|
except ModuleNotFoundError as e: |
568
|
|
|
message = ( |
569
|
|
|
'Please install Pillow to use Image.as_pil():' |
570
|
|
|
' pip install Pillow' |
571
|
|
|
) |
572
|
|
|
raise RuntimeError(message) from e |
573
|
|
|
|
574
|
|
|
self.check_is_2d() |
575
|
|
|
tensor = self.data |
576
|
|
|
if len(tensor) == 1: |
577
|
|
|
tensor = torch.cat(3 * [tensor]) |
578
|
|
|
if len(tensor) != 3: |
579
|
|
|
raise RuntimeError('The image must have 1 or 3 channels') |
580
|
|
|
if transpose: |
581
|
|
|
tensor = tensor.permute(3, 2, 1, 0) |
582
|
|
|
else: |
583
|
|
|
tensor = tensor.permute(3, 1, 2, 0) |
584
|
|
|
array = tensor.clamp(0, 255).numpy()[0] |
585
|
|
|
return ImagePIL.fromarray(array.astype(np.uint8)) |
586
|
|
|
|
587
|
|
|
def get_center(self, lps: bool = False) -> TypeTripletFloat: |
588
|
|
|
"""Get image center in RAS+ or LPS+ coordinates. |
589
|
|
|
|
590
|
|
|
Args: |
591
|
|
|
lps: If ``True``, the coordinates will be in LPS+ orientation, i.e. |
592
|
|
|
the first dimension grows towards the left, etc. Otherwise, the |
593
|
|
|
coordinates will be in RAS+ orientation. |
594
|
|
|
""" |
595
|
|
|
size = np.array(self.spatial_shape) |
596
|
|
|
center_index = (size - 1) / 2 |
597
|
|
|
r, a, s = nib.affines.apply_affine(self.affine, center_index) |
598
|
|
|
if lps: |
599
|
|
|
return (-r, -a, s) |
600
|
|
|
else: |
601
|
|
|
return (r, a, s) |
602
|
|
|
|
603
|
|
|
def set_check_nans(self, check_nans: bool) -> None: |
604
|
|
|
self.check_nans = check_nans |
605
|
|
|
|
606
|
|
|
def plot(self, **kwargs) -> None: |
607
|
|
|
"""Plot image.""" |
608
|
|
|
if self.is_2d(): |
609
|
|
|
self.as_pil().show() |
610
|
|
|
else: |
611
|
|
|
from ..visualization import plot_volume # avoid circular import |
612
|
|
|
plot_volume(self, **kwargs) |
613
|
|
|
|
614
|
|
|
|
615
|
|
|
class ScalarImage(Image): |
616
|
|
|
"""Image whose pixel values represent scalars. |
617
|
|
|
|
618
|
|
|
Example: |
619
|
|
|
>>> import torch |
620
|
|
|
>>> import torchio as tio |
621
|
|
|
>>> # Loading from a file |
622
|
|
|
>>> t1_image = tio.ScalarImage('t1.nii.gz') |
623
|
|
|
>>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88)) |
624
|
|
|
>>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False) |
625
|
|
|
>>> data, affine = image.data, image.affine |
626
|
|
|
>>> affine.shape |
627
|
|
|
(4, 4) |
628
|
|
|
>>> image.data is image[tio.DATA] |
629
|
|
|
True |
630
|
|
|
>>> image.data is image.tensor |
631
|
|
|
True |
632
|
|
|
>>> type(image.data) |
633
|
|
|
torch.Tensor |
634
|
|
|
|
635
|
|
|
See :class:`~torchio.Image` for more information. |
636
|
|
|
""" |
637
|
|
|
def __init__(self, *args, **kwargs): |
638
|
|
|
if 'type' in kwargs and kwargs['type'] != INTENSITY: |
639
|
|
|
raise ValueError('Type of ScalarImage is always torchio.INTENSITY') |
640
|
|
|
kwargs.update({'type': INTENSITY}) |
641
|
|
|
super().__init__(*args, **kwargs) |
642
|
|
|
|
643
|
|
|
|
644
|
|
|
class LabelMap(Image): |
645
|
|
|
"""Image whose pixel values represent categorical labels. |
646
|
|
|
|
647
|
|
|
Example: |
648
|
|
|
>>> import torch |
649
|
|
|
>>> import torchio as tio |
650
|
|
|
>>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5) |
651
|
|
|
>>> labels = tio.LabelMap('t1_seg.nii.gz') # loading from a file |
652
|
|
|
>>> tpm = tio.LabelMap( # loading from files |
653
|
|
|
... 'gray_matter.nii.gz', |
654
|
|
|
... 'white_matter.nii.gz', |
655
|
|
|
... 'csf.nii.gz', |
656
|
|
|
... ) |
657
|
|
|
|
658
|
|
|
Intensity transforms are not applied to these images. |
659
|
|
|
|
660
|
|
|
Nearest neighbor interpolation is always used to resample label maps, |
661
|
|
|
independently of the specified interpolation type in the transform |
662
|
|
|
instantiation. |
663
|
|
|
|
664
|
|
|
See :class:`~torchio.Image` for more information. |
665
|
|
|
""" |
666
|
|
|
def __init__(self, *args, **kwargs): |
667
|
|
|
if 'type' in kwargs and kwargs['type'] != LABEL: |
668
|
|
|
raise ValueError('Type of LabelMap is always torchio.LABEL') |
669
|
|
|
kwargs.update({'type': LABEL}) |
670
|
|
|
super().__init__(*args, **kwargs) |
671
|
|
|
|