1
|
|
|
import warnings |
2
|
|
|
from pathlib import Path |
3
|
|
|
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List |
4
|
|
|
|
5
|
|
|
import torch |
6
|
|
|
import humanize |
7
|
|
|
import numpy as np |
8
|
|
|
import nibabel as nib |
9
|
|
|
import SimpleITK as sitk |
10
|
|
|
from deprecated import deprecated |
11
|
|
|
|
12
|
|
|
from ..utils import get_stem |
13
|
|
|
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat |
14
|
|
|
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL |
15
|
|
|
from .io import ( |
16
|
|
|
ensure_4d, |
17
|
|
|
read_image, |
18
|
|
|
write_image, |
19
|
|
|
nib_to_sitk, |
20
|
|
|
check_uint_to_int, |
21
|
|
|
get_rotation_and_spacing_from_affine, |
22
|
|
|
) |
23
|
|
|
|
24
|
|
|
|
25
|
|
|
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM |
26
|
|
|
TypeBound = Tuple[float, float] |
27
|
|
|
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound] |
28
|
|
|
|
29
|
|
|
deprecation_message = ( |
30
|
|
|
'Setting the image data with the property setter is deprecated. Use the' |
31
|
|
|
' set_data() method instead' |
32
|
|
|
) |
33
|
|
|
|
34
|
|
|
|
35
|
|
|
class Image(dict): |
36
|
|
|
r"""TorchIO image. |
37
|
|
|
|
38
|
|
|
For information about medical image orientation, check out `NiBabel docs`_, |
39
|
|
|
the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or |
40
|
|
|
`SimpleITK docs`_. |
41
|
|
|
|
42
|
|
|
Args: |
43
|
|
|
path: Path to a file or sequence of paths to files that can be read by |
44
|
|
|
:mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing |
45
|
|
|
DICOM files. If :attr:`tensor` is given, the data in |
46
|
|
|
:attr:`path` will not be read. |
47
|
|
|
If a sequence of paths is given, data |
48
|
|
|
will be concatenated on the channel dimension so spatial |
49
|
|
|
dimensions must match. |
50
|
|
|
type: Type of image, such as :attr:`torchio.INTENSITY` or |
51
|
|
|
:attr:`torchio.LABEL`. This will be used by the transforms to |
52
|
|
|
decide whether to apply an operation, or which interpolation to use |
53
|
|
|
when resampling. For example, `preprocessing`_ and `augmentation`_ |
54
|
|
|
intensity transforms will only be applied to images with type |
55
|
|
|
:attr:`torchio.INTENSITY`. Spatial transforms will be applied to |
56
|
|
|
all types, and nearest neighbor interpolation is always used to |
57
|
|
|
resample images with type :attr:`torchio.LABEL`. |
58
|
|
|
The type :attr:`torchio.SAMPLING_MAP` may be used with instances of |
59
|
|
|
:class:`~torchio.data.sampler.weighted.WeightedSampler`. |
60
|
|
|
tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D |
61
|
|
|
:class:`torch.Tensor` or NumPy array with dimensions |
62
|
|
|
:math:`(C, W, H, D)`. |
63
|
|
|
affine: If :attr:`path` is not given, :attr:`affine` must be a |
64
|
|
|
:math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an |
65
|
|
|
identity matrix. |
66
|
|
|
check_nans: If ``True``, issues a warning if NaNs are found |
67
|
|
|
in the image. If ``False``, images will not be checked for the |
68
|
|
|
presence of NaNs. |
69
|
|
|
**kwargs: Items that will be added to the image dictionary, e.g. |
70
|
|
|
acquisition parameters. |
71
|
|
|
|
72
|
|
|
TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk |
73
|
|
|
when needed. |
74
|
|
|
|
75
|
|
|
Example: |
76
|
|
|
>>> import torchio as tio |
77
|
|
|
>>> image = tio.ScalarImage('t1.nii.gz') # subclass of Image |
78
|
|
|
>>> image # not loaded yet |
79
|
|
|
ScalarImage(path: t1.nii.gz; type: intensity) |
80
|
|
|
>>> times_two = 2 * image.data # data is loaded and cached here |
81
|
|
|
>>> image |
82
|
|
|
ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity) |
83
|
|
|
>>> image.save('doubled_image.nii.gz') |
84
|
|
|
|
85
|
|
|
.. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading |
86
|
|
|
.. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity |
87
|
|
|
.. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity |
88
|
|
|
.. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html |
89
|
|
|
.. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems |
90
|
|
|
.. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained |
91
|
|
|
.. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html |
92
|
|
|
.. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm |
93
|
|
|
""" |
94
|
|
|
def __init__( |
95
|
|
|
self, |
96
|
|
|
path: Union[TypePath, Sequence[TypePath], None] = None, |
97
|
|
|
type: str = None, |
98
|
|
|
tensor: Optional[TypeData] = None, |
99
|
|
|
affine: Optional[TypeData] = None, |
100
|
|
|
check_nans: bool = False, # removed by ITK by default |
101
|
|
|
channels_last: bool = False, |
102
|
|
|
**kwargs: Dict[str, Any], |
103
|
|
|
): |
104
|
|
|
self.check_nans = check_nans |
105
|
|
|
self.channels_last = channels_last |
106
|
|
|
|
107
|
|
|
if type is None: |
108
|
|
|
warnings.warn( |
109
|
|
|
'Not specifying the image type is deprecated and will be' |
110
|
|
|
' mandatory in the future. You can probably use tio.ScalarImage' |
111
|
|
|
' or tio.LabelMap instead', DeprecationWarning, |
112
|
|
|
) |
113
|
|
|
type = INTENSITY |
114
|
|
|
|
115
|
|
|
if path is None and tensor is None: |
116
|
|
|
raise ValueError('A value for path or tensor must be given') |
117
|
|
|
self._loaded = False |
118
|
|
|
|
119
|
|
|
tensor = self._parse_tensor(tensor) |
120
|
|
|
affine = self._parse_affine(affine) |
121
|
|
|
if tensor is not None: |
122
|
|
|
self.set_data(tensor) |
123
|
|
|
self.affine = affine |
124
|
|
|
self._loaded = True |
125
|
|
|
for key in PROTECTED_KEYS: |
126
|
|
|
if key in kwargs: |
127
|
|
|
message = f'Key "{key}" is reserved. Use a different one' |
128
|
|
|
raise ValueError(message) |
129
|
|
|
|
130
|
|
|
super().__init__(**kwargs) |
131
|
|
|
self.path = self._parse_path(path) |
132
|
|
|
|
133
|
|
|
self[PATH] = '' if self.path is None else str(self.path) |
134
|
|
|
self[STEM] = '' if self.path is None else get_stem(self.path) |
135
|
|
|
self[TYPE] = type |
136
|
|
|
|
137
|
|
|
def __repr__(self): |
138
|
|
|
properties = [] |
139
|
|
|
if self._loaded: |
140
|
|
|
properties.extend([ |
141
|
|
|
f'shape: {self.shape}', |
142
|
|
|
f'spacing: {self.get_spacing_string()}', |
143
|
|
|
f'orientation: {"".join(self.orientation)}+', |
144
|
|
|
f'memory: {humanize.naturalsize(self.memory, binary=True)}', |
145
|
|
|
]) |
146
|
|
|
else: |
147
|
|
|
properties.append(f'path: "{self.path}"') |
148
|
|
|
if self._loaded: |
149
|
|
|
properties.append(f'dtype: {self.data.type()}') |
150
|
|
|
properties = '; '.join(properties) |
151
|
|
|
string = f'{self.__class__.__name__}({properties})' |
152
|
|
|
return string |
153
|
|
|
|
154
|
|
|
def __getitem__(self, item): |
155
|
|
|
if item in (DATA, AFFINE): |
156
|
|
|
if item not in self: |
157
|
|
|
self.load() |
158
|
|
|
return super().__getitem__(item) |
159
|
|
|
|
160
|
|
|
def __array__(self): |
161
|
|
|
return self.data.numpy() |
162
|
|
|
|
163
|
|
|
def __copy__(self): |
164
|
|
|
kwargs = dict( |
165
|
|
|
tensor=self.data, |
166
|
|
|
affine=self.affine, |
167
|
|
|
type=self.type, |
168
|
|
|
path=self.path, |
169
|
|
|
) |
170
|
|
|
for key, value in self.items(): |
171
|
|
|
if key in PROTECTED_KEYS: continue |
172
|
|
|
kwargs[key] = value # should I copy? deepcopy? |
173
|
|
|
return self.__class__(**kwargs) |
174
|
|
|
|
175
|
|
|
@property |
176
|
|
|
def data(self) -> torch.Tensor: |
177
|
|
|
"""Tensor data. Same as :class:`Image.tensor`.""" |
178
|
|
|
return self[DATA] |
179
|
|
|
|
180
|
|
|
@data.setter |
181
|
|
|
@deprecated(version='0.18.16', reason=deprecation_message) |
182
|
|
|
def data(self, tensor: TypeData): |
183
|
|
|
self.set_data(tensor) |
184
|
|
|
|
185
|
|
|
def set_data(self, tensor: TypeData): |
186
|
|
|
"""Store a 4D tensor in the :attr:`data` key and attribute. |
187
|
|
|
|
188
|
|
|
Args: |
189
|
|
|
tensor: 4D tensor with dimensions :math:`(C, W, H, D)`. |
190
|
|
|
""" |
191
|
|
|
self[DATA] = self._parse_tensor(tensor, none_ok=False) |
192
|
|
|
|
193
|
|
|
@property |
194
|
|
|
def tensor(self) -> torch.Tensor: |
195
|
|
|
"""Tensor data. Same as :class:`Image.data`.""" |
196
|
|
|
return self.data |
197
|
|
|
|
198
|
|
|
@property |
199
|
|
|
def affine(self) -> np.ndarray: |
200
|
|
|
"""Affine matrix to transform voxel indices into world coordinates.""" |
201
|
|
|
return self[AFFINE] |
202
|
|
|
|
203
|
|
|
@affine.setter |
204
|
|
|
def affine(self, matrix): |
205
|
|
|
self[AFFINE] = self._parse_affine(matrix) |
206
|
|
|
|
207
|
|
|
@property |
208
|
|
|
def type(self) -> str: |
209
|
|
|
return self[TYPE] |
210
|
|
|
|
211
|
|
|
@property |
212
|
|
|
def shape(self) -> Tuple[int, int, int, int]: |
213
|
|
|
"""Tensor shape as :math:`(C, W, H, D)`.""" |
214
|
|
|
return tuple(self.data.shape) |
215
|
|
|
|
216
|
|
|
@property |
217
|
|
|
def spatial_shape(self) -> TypeTripletInt: |
218
|
|
|
"""Tensor spatial shape as :math:`(W, H, D)`.""" |
219
|
|
|
return self.shape[1:] |
220
|
|
|
|
221
|
|
|
def check_is_2d(self) -> None: |
222
|
|
|
if not self.is_2d(): |
223
|
|
|
message = f'Image is not 2D. Spatial shape: {self.spatial_shape}' |
224
|
|
|
raise RuntimeError(message) |
225
|
|
|
|
226
|
|
|
@property |
227
|
|
|
def height(self) -> int: |
228
|
|
|
"""Image height, if 2D.""" |
229
|
|
|
self.check_is_2d() |
230
|
|
|
return self.spatial_shape[1] |
231
|
|
|
|
232
|
|
|
@property |
233
|
|
|
def width(self) -> int: |
234
|
|
|
"""Image width, if 2D.""" |
235
|
|
|
self.check_is_2d() |
236
|
|
|
return self.spatial_shape[0] |
237
|
|
|
|
238
|
|
|
@property |
239
|
|
|
def orientation(self) -> Tuple[str, str, str]: |
240
|
|
|
"""Orientation codes.""" |
241
|
|
|
return nib.aff2axcodes(self.affine) |
242
|
|
|
|
243
|
|
|
@property |
244
|
|
|
def spacing(self) -> Tuple[float, float, float]: |
245
|
|
|
"""Voxel spacing in mm.""" |
246
|
|
|
_, spacing = get_rotation_and_spacing_from_affine(self.affine) |
247
|
|
|
return tuple(spacing) |
248
|
|
|
|
249
|
|
|
@property |
250
|
|
|
def itemsize(self): |
251
|
|
|
"""Element size of the data type.""" |
252
|
|
|
return self.data.element_size() |
253
|
|
|
|
254
|
|
|
@property |
255
|
|
|
def memory(self) -> float: |
256
|
|
|
"""Number of Bytes that the tensor takes in the RAM.""" |
257
|
|
|
return np.prod(self.shape) * self.itemsize |
258
|
|
|
|
259
|
|
|
@property |
260
|
|
|
def bounds(self) -> np.ndarray: |
261
|
|
|
"""Position of centers of voxels in smallest and largest coordinates.""" |
262
|
|
|
ini = 0, 0, 0 |
263
|
|
|
fin = np.array(self.spatial_shape) - 1 |
264
|
|
|
point_ini = nib.affines.apply_affine(self.affine, ini) |
265
|
|
|
point_fin = nib.affines.apply_affine(self.affine, fin) |
266
|
|
|
return np.array((point_ini, point_fin)) |
267
|
|
|
|
268
|
|
|
def axis_name_to_index(self, axis: str) -> int: |
269
|
|
|
"""Convert an axis name to an axis index. |
270
|
|
|
|
271
|
|
|
Args: |
272
|
|
|
axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``, |
273
|
|
|
``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case |
274
|
|
|
versions and first letters are also valid, as only the first |
275
|
|
|
letter will be used. |
276
|
|
|
|
277
|
|
|
.. note:: If you are working with animals, you should probably use |
278
|
|
|
``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'`` |
279
|
|
|
for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``, |
280
|
|
|
respectively. |
281
|
|
|
|
282
|
|
|
.. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``, |
283
|
|
|
``'Left'`` and ``'Right'``. |
284
|
|
|
""" |
285
|
|
|
# Top and bottom are used for the vertical 2D axis as the use of |
286
|
|
|
# Height vs Horizontal might be ambiguous |
287
|
|
|
|
288
|
|
|
if not isinstance(axis, str): |
289
|
|
|
raise ValueError('Axis must be a string') |
290
|
|
|
axis = axis[0].upper() |
291
|
|
|
|
292
|
|
|
# Generally, TorchIO tensors are (C, W, H, D) |
293
|
|
|
if axis in 'TB': # Top, Bottom |
294
|
|
|
return -2 |
295
|
|
|
else: |
296
|
|
|
try: |
297
|
|
|
index = self.orientation.index(axis) |
298
|
|
|
except ValueError: |
299
|
|
|
index = self.orientation.index(self.flip_axis(axis)) |
300
|
|
|
# Return negative indices so that it does not matter whether we |
301
|
|
|
# refer to spatial dimensions or not |
302
|
|
|
index = -3 + index |
303
|
|
|
return index |
304
|
|
|
|
305
|
|
|
# flake8: noqa: E701 |
306
|
|
|
@staticmethod |
307
|
|
|
def flip_axis(axis: str) -> str: |
308
|
|
|
if axis == 'R': flipped_axis = 'L' |
309
|
|
|
elif axis == 'L': flipped_axis = 'R' |
310
|
|
|
elif axis == 'A': flipped_axis = 'P' |
311
|
|
|
elif axis == 'P': flipped_axis = 'A' |
312
|
|
|
elif axis == 'I': flipped_axis = 'S' |
313
|
|
|
elif axis == 'S': flipped_axis = 'I' |
314
|
|
|
elif axis == 'T': flipped_axis = 'B' |
315
|
|
|
elif axis == 'B': flipped_axis = 'T' |
316
|
|
|
else: |
317
|
|
|
values = ', '.join('LRPAISTB') |
318
|
|
|
message = f'Axis not understood. Please use one of: {values}' |
319
|
|
|
raise ValueError(message) |
320
|
|
|
return flipped_axis |
321
|
|
|
|
322
|
|
|
def get_spacing_string(self) -> str: |
323
|
|
|
strings = [f'{n:.2f}' for n in self.spacing] |
324
|
|
|
string = f'({", ".join(strings)})' |
325
|
|
|
return string |
326
|
|
|
|
327
|
|
|
def get_bounds(self) -> TypeBounds: |
328
|
|
|
"""Get minimum and maximum world coordinates occupied by the image.""" |
329
|
|
|
first_index = 3 * (-0.5,) |
330
|
|
|
last_index = np.array(self.spatial_shape) - 0.5 |
331
|
|
|
first_point = nib.affines.apply_affine(self.affine, first_index) |
332
|
|
|
last_point = nib.affines.apply_affine(self.affine, last_index) |
333
|
|
|
array = np.array((first_point, last_point)) |
334
|
|
|
bounds_x, bounds_y, bounds_z = array.T.tolist() |
335
|
|
|
return bounds_x, bounds_y, bounds_z |
336
|
|
|
|
337
|
|
|
@staticmethod |
338
|
|
|
def _parse_single_path( |
339
|
|
|
path: TypePath |
340
|
|
|
) -> Path: |
341
|
|
|
try: |
342
|
|
|
path = Path(path).expanduser() |
343
|
|
|
except TypeError: |
344
|
|
|
message = ( |
345
|
|
|
f'Expected type str or Path but found {path} with type' |
346
|
|
|
f' {type(path)} instead' |
347
|
|
|
) |
348
|
|
|
raise TypeError(message) |
349
|
|
|
except RuntimeError: |
350
|
|
|
message = ( |
351
|
|
|
f'Conversion to path not possible for variable: {path}' |
352
|
|
|
) |
353
|
|
|
raise RuntimeError(message) |
354
|
|
|
|
355
|
|
|
if not (path.is_file() or path.is_dir()): # might be a dir with DICOM |
356
|
|
|
raise FileNotFoundError(f'File not found: "{path}"') |
357
|
|
|
return path |
358
|
|
|
|
359
|
|
|
def _parse_path( |
360
|
|
|
self, |
361
|
|
|
path: Union[TypePath, Sequence[TypePath]] |
362
|
|
|
) -> Optional[Union[Path, List[Path]]]: |
363
|
|
|
if path is None: |
364
|
|
|
return None |
365
|
|
|
if isinstance(path, (str, Path)): |
366
|
|
|
return self._parse_single_path(path) |
367
|
|
|
else: |
368
|
|
|
return [self._parse_single_path(p) for p in path] |
369
|
|
|
|
370
|
|
|
def _parse_tensor( |
371
|
|
|
self, |
372
|
|
|
tensor: TypeData, |
373
|
|
|
none_ok: bool = True, |
374
|
|
|
) -> torch.Tensor: |
375
|
|
|
if tensor is None: |
376
|
|
|
if none_ok: |
377
|
|
|
return None |
378
|
|
|
else: |
379
|
|
|
raise RuntimeError('Input tensor cannot be None') |
380
|
|
|
if isinstance(tensor, np.ndarray): |
381
|
|
|
tensor = check_uint_to_int(tensor) |
382
|
|
|
tensor = torch.from_numpy(tensor) |
383
|
|
|
elif not isinstance(tensor, torch.Tensor): |
384
|
|
|
message = 'Input tensor must be a PyTorch tensor or NumPy array' |
385
|
|
|
raise TypeError(message) |
386
|
|
|
if tensor.ndim != 4: |
387
|
|
|
raise ValueError('Input tensor must be 4D') |
388
|
|
|
if tensor.dtype == torch.bool: |
389
|
|
|
tensor = tensor.to(torch.uint8) |
390
|
|
|
if self.check_nans and torch.isnan(tensor).any(): |
391
|
|
|
warnings.warn(f'NaNs found in tensor', RuntimeWarning) |
392
|
|
|
return tensor |
393
|
|
|
|
394
|
|
|
def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor: |
395
|
|
|
return ensure_4d(tensor) |
396
|
|
|
|
397
|
|
|
@staticmethod |
398
|
|
|
def _parse_affine(affine: TypeData) -> np.ndarray: |
399
|
|
|
if affine is None: |
400
|
|
|
return np.eye(4) |
401
|
|
|
if isinstance(affine, torch.Tensor): |
402
|
|
|
affine = affine.numpy() |
403
|
|
|
if not isinstance(affine, np.ndarray): |
404
|
|
|
raise TypeError(f'Affine must be a NumPy array, not {type(affine)}') |
405
|
|
|
if affine.shape != (4, 4): |
406
|
|
|
raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}') |
407
|
|
|
return affine |
408
|
|
|
|
409
|
|
|
def load(self) -> None: |
410
|
|
|
r"""Load the image from disk. |
411
|
|
|
|
412
|
|
|
Returns: |
413
|
|
|
Tuple containing a 4D tensor of size :math:`(C, W, H, D)` and a 2D |
414
|
|
|
:math:`4 \times 4` affine matrix to convert voxel indices to world |
415
|
|
|
coordinates. |
416
|
|
|
""" |
417
|
|
|
if self._loaded: |
418
|
|
|
return |
419
|
|
|
paths = self.path if isinstance(self.path, list) else [self.path] |
420
|
|
|
tensor, affine = self.read_and_check(paths[0]) |
421
|
|
|
tensors = [tensor] |
422
|
|
|
for path in paths[1:]: |
423
|
|
|
new_tensor, new_affine = self.read_and_check(path) |
424
|
|
|
if not np.array_equal(affine, new_affine): |
425
|
|
|
message = ( |
426
|
|
|
'Files have different affine matrices.' |
427
|
|
|
f'\nMatrix of {paths[0]}:' |
428
|
|
|
f'\n{affine}' |
429
|
|
|
f'\nMatrix of {path}:' |
430
|
|
|
f'\n{new_affine}' |
431
|
|
|
) |
432
|
|
|
warnings.warn(message, RuntimeWarning) |
433
|
|
|
if not tensor.shape[1:] == new_tensor.shape[1:]: |
434
|
|
|
message = ( |
435
|
|
|
f'Files shape do not match, found {tensor.shape}' |
436
|
|
|
f'and {new_tensor.shape}' |
437
|
|
|
) |
438
|
|
|
RuntimeError(message) |
439
|
|
|
tensors.append(new_tensor) |
440
|
|
|
tensor = torch.cat(tensors) |
441
|
|
|
self.set_data(tensor) |
442
|
|
|
self.affine = affine |
443
|
|
|
self._loaded = True |
444
|
|
|
|
445
|
|
|
def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]: |
446
|
|
|
tensor, affine = read_image(path) |
447
|
|
|
tensor = self.parse_tensor_shape(tensor) |
448
|
|
|
if self.channels_last: |
449
|
|
|
tensor = tensor.permute(3, 0, 1, 2) |
450
|
|
|
if self.check_nans and torch.isnan(tensor).any(): |
451
|
|
|
warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning) |
452
|
|
|
return tensor, affine |
453
|
|
|
|
454
|
|
|
def save(self, path: TypePath, squeeze: bool = True) -> None: |
455
|
|
|
"""Save image to disk. |
456
|
|
|
|
457
|
|
|
Args: |
458
|
|
|
path: String or instance of :class:`pathlib.Path`. |
459
|
|
|
squeeze: If ``True``, singleton dimensions will be removed |
460
|
|
|
before saving. |
461
|
|
|
""" |
462
|
|
|
write_image( |
463
|
|
|
self.data, |
464
|
|
|
self.affine, |
465
|
|
|
path, |
466
|
|
|
squeeze=squeeze, |
467
|
|
|
) |
468
|
|
|
|
469
|
|
|
def is_2d(self) -> bool: |
470
|
|
|
return self.shape[-1] == 1 |
471
|
|
|
|
472
|
|
|
def numpy(self) -> np.ndarray: |
473
|
|
|
"""Get a NumPy array containing the image data.""" |
474
|
|
|
return np.asarray(self) |
475
|
|
|
|
476
|
|
|
def as_sitk(self, **kwargs) -> sitk.Image: |
477
|
|
|
"""Get the image as an instance of :class:`sitk.Image`.""" |
478
|
|
|
return nib_to_sitk(self.data, self.affine, **kwargs) |
479
|
|
|
|
480
|
|
|
def as_pil(self): |
481
|
|
|
"""Get the image as an instance of :class:`PIL.Image`. |
482
|
|
|
|
483
|
|
|
.. note:: Values will be clamped to 0-255 and cast to uint8. |
484
|
|
|
.. note:: To use this method, `Pillow` needs to be installed: |
485
|
|
|
`pip install Pillow`. |
486
|
|
|
""" |
487
|
|
|
try: |
488
|
|
|
from PIL import Image as ImagePIL |
489
|
|
|
except ModuleNotFoundError as e: |
490
|
|
|
message = ( |
491
|
|
|
'Please install Pillow to use Image.as_pil():' |
492
|
|
|
' pip install Pillow' |
493
|
|
|
) |
494
|
|
|
raise RuntimeError(message) from e |
495
|
|
|
|
496
|
|
|
self.check_is_2d() |
497
|
|
|
tensor = self.data |
498
|
|
|
if len(tensor) == 1: |
499
|
|
|
tensor = torch.cat(3 * [tensor]) |
500
|
|
|
if len(tensor) != 3: |
501
|
|
|
raise RuntimeError('The image must have 1 or 3 channels') |
502
|
|
|
tensor = tensor.permute(3, 1, 2, 0)[0] |
503
|
|
|
array = tensor.clamp(0, 255).numpy() |
504
|
|
|
return ImagePIL.fromarray(array.astype(np.uint8)) |
505
|
|
|
|
506
|
|
|
def get_center(self, lps: bool = False) -> TypeTripletFloat: |
507
|
|
|
"""Get image center in RAS+ or LPS+ coordinates. |
508
|
|
|
|
509
|
|
|
Args: |
510
|
|
|
lps: If ``True``, the coordinates will be in LPS+ orientation, i.e. |
511
|
|
|
the first dimension grows towards the left, etc. Otherwise, the |
512
|
|
|
coordinates will be in RAS+ orientation. |
513
|
|
|
""" |
514
|
|
|
size = np.array(self.spatial_shape) |
515
|
|
|
center_index = (size - 1) / 2 |
516
|
|
|
r, a, s = nib.affines.apply_affine(self.affine, center_index) |
517
|
|
|
if lps: |
518
|
|
|
return (-r, -a, s) |
519
|
|
|
else: |
520
|
|
|
return (r, a, s) |
521
|
|
|
|
522
|
|
|
def set_check_nans(self, check_nans: bool) -> None: |
523
|
|
|
self.check_nans = check_nans |
524
|
|
|
|
525
|
|
|
def plot(self, **kwargs) -> None: |
526
|
|
|
"""Plot image.""" |
527
|
|
|
if self.is_2d(): |
528
|
|
|
self.as_pil().show() |
529
|
|
|
else: |
530
|
|
|
from ..visualization import plot_volume # avoid circular import |
531
|
|
|
plot_volume(self, **kwargs) |
532
|
|
|
|
533
|
|
|
|
534
|
|
|
class ScalarImage(Image): |
535
|
|
|
"""Image whose pixel values represent scalars. |
536
|
|
|
|
537
|
|
|
Example: |
538
|
|
|
>>> import torch |
539
|
|
|
>>> import torchio as tio |
540
|
|
|
>>> # Loading from a file |
541
|
|
|
>>> t1_image = tio.ScalarImage('t1.nii.gz') |
542
|
|
|
>>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88)) |
543
|
|
|
>>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False) |
544
|
|
|
>>> data, affine = image.data, image.affine |
545
|
|
|
>>> affine.shape |
546
|
|
|
(4, 4) |
547
|
|
|
>>> image.data is image[tio.DATA] |
548
|
|
|
True |
549
|
|
|
>>> image.data is image.tensor |
550
|
|
|
True |
551
|
|
|
>>> type(image.data) |
552
|
|
|
torch.Tensor |
553
|
|
|
|
554
|
|
|
See :class:`~torchio.Image` for more information. |
555
|
|
|
""" |
556
|
|
|
def __init__(self, *args, **kwargs): |
557
|
|
|
if 'type' in kwargs and kwargs['type'] != INTENSITY: |
558
|
|
|
raise ValueError('Type of ScalarImage is always torchio.INTENSITY') |
559
|
|
|
kwargs.update({'type': INTENSITY}) |
560
|
|
|
super().__init__(*args, **kwargs) |
561
|
|
|
|
562
|
|
|
|
563
|
|
|
class LabelMap(Image): |
564
|
|
|
"""Image whose pixel values represent categorical labels. |
565
|
|
|
|
566
|
|
|
Intensity transforms are not applied to these images. |
567
|
|
|
|
568
|
|
|
Example: |
569
|
|
|
>>> import torch |
570
|
|
|
>>> import torchio as tio |
571
|
|
|
>>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5) |
572
|
|
|
>>> labels = tio.LabelMap('t1_seg.nii.gz') # loading from a file |
573
|
|
|
>>> tpm = tio.LabelMap( # loading from files |
574
|
|
|
... 'gray_matter.nii.gz', |
575
|
|
|
... 'white_matter.nii.gz', |
576
|
|
|
... 'csf.nii.gz', |
577
|
|
|
... ) |
578
|
|
|
|
579
|
|
|
See :class:`~torchio.Image` for more information. |
580
|
|
|
""" |
581
|
|
|
def __init__(self, *args, **kwargs): |
582
|
|
|
if 'type' in kwargs and kwargs['type'] != LABEL: |
583
|
|
|
raise ValueError('Type of LabelMap is always torchio.LABEL') |
584
|
|
|
kwargs.update({'type': LABEL}) |
585
|
|
|
super().__init__(*args, **kwargs) |
586
|
|
|
|