from __future__ import annotations

import warnings
from collections import Counter
from collections.abc import Sequence
from pathlib import Path
from typing import Any
from typing import Callable

import humanize
import nibabel as nib
import numpy as np
import SimpleITK as sitk
import torch
from deprecated import deprecated
from nibabel.affines import apply_affine

from ..constants import AFFINE
from ..constants import DATA
from ..constants import INTENSITY
from ..constants import LABEL
from ..constants import PATH
from ..constants import STEM
from ..constants import TENSOR
from ..constants import TYPE
from ..types import TypeData
from ..types import TypeDataAffine
from ..types import TypeDirection3D
from ..types import TypePath
from ..types import TypeQuartetInt
from ..types import TypeSlice
from ..types import TypeTripletFloat
from ..types import TypeTripletInt
from ..utils import get_stem
from ..utils import guess_external_viewer
from ..utils import is_iterable
from ..utils import to_tuple
from .io import check_uint_to_int
from .io import ensure_4d
from .io import get_rotation_and_spacing_from_affine
from .io import get_sitk_metadata_from_ras_affine
from .io import nib_to_sitk
from .io import read_affine
from .io import read_image
from .io import read_shape
from .io import sitk_to_nib
from .io import write_image

PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
TypeBound = tuple[float, float]
TypeBounds = tuple[TypeBound, TypeBound, TypeBound]

deprecation_message = (
    'Setting the image data with the property setter is deprecated. Use the'
    ' set_data() method instead'
)


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
            coordinates. If ``None``, an identity matrix will be used. See the
            `NiBabel docs on coordinates`_ for more information.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        reader: Callable object that takes a path and returns a 4D tensor and a
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
            is saved in a custom format, such as ``.npy`` (see example below).
            If the affine matrix is ``None``, an identity matrix will be used.
        verify_path: If ``True``, the path will be checked to see if it exists.
            If ``False``, the path will not be verified. This is useful when it
            is expensive to check the path, e.g., when reading a large dataset
            from a mounted drive.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters or image ID.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> import numpy as np
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: t1.nii.gz; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')
        >>> def numpy_reader(path):
        ...     data = np.load(path).astype(np.float32)
        ...     affine = np.eye(4)
        ...     return data, affine
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://docs.torchio.org/transforms/preprocessing.html#intensity
    .. _augmentation: https://docs.torchio.org/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """

    def __init__(
        self,
        path: TypePath | Sequence[TypePath] | None = None,
        type: str | None = None,  # noqa: A002
        tensor: TypeData | None = None,
        affine: TypeData | None = None,
        check_nans: bool = False,  # removed by ITK by default
        reader: Callable[[TypePath], TypeDataAffine] = read_image,
        verify_path: bool = True,
        **kwargs: dict[str, Any],
    ):
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                FutureWarning,
                stacklevel=2,
            )
            type = INTENSITY  # noqa: A001

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        if 'channels_last' in kwargs:
            message = (
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/TorchIO-project/torchio/pull/685 and will be'
                ' removed in the future'
            )
            warnings.warn(message, FutureWarning, stacklevel=2)

        super().__init__(**kwargs)
        self.path = self._parse_path(path, verify=verify_path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        properties.extend(
            [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {self.orientation_str}+',
            ]
        )
        if self._loaded:
            properties.append(f'dtype: {self.data.type()}')
            natural = humanize.naturalsize(self.memory, binary=True)
            properties.append(f'memory: {natural}')
        else:
            properties.append(f'path: "{self.path}"')

        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        if isinstance(item, (slice, int, tuple)):
            return self._crop_from_slices(item)

        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        kwargs = {
            TYPE: self.type,
            PATH: self.path,
        }
        if self._loaded:
            kwargs[TENSOR] = self.data
            kwargs[AFFINE] = self.affine
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        new_image_class = type(self)
        new_image = new_image_class(
            check_nans=self.check_nans,
            reader=self.reader,
            **kwargs,
        )
        return new_image

    @property
    def data(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.tensor`)."""
        return self[DATA]

    @data.setter  # type: ignore[misc]
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)
        self._loaded = True

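    # A minimal usage sketch of set_data() (hypothetical shapes, not taken
    # from the docs): the new tensor replaces the cached one in place.
    #
    #     >>> image = ScalarImage(tensor=torch.rand(1, 8, 8, 8))
    #     >>> image.set_data(torch.zeros(1, 8, 8, 8))
    #     >>> int(image.data.count_nonzero())
    #     0
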
    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.data`)."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        # If path is a dir (probably DICOM), just load the data
        # Same if it's a list of paths (used to create a 4D image)
        # Finally, if we use a custom reader, SimpleITK probably won't be able
        # to read the metadata, so we resort to loading everything into memory
        is_custom_reader = self.reader is not read_image
        if self._loaded or self._is_dir() or self._is_multipath() or is_custom_reader:
            affine = self[AFFINE]
        else:
            assert self.path is not None
            assert isinstance(self.path, (str, Path))
            affine = read_affine(self.path)
        return affine

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:  # noqa: A003
        return self[TYPE]

    @property
    def shape(self) -> TypeQuartetInt:
        """Tensor shape as :math:`(C, W, H, D)`."""
        custom_reader = self.reader is not read_image
        multipath = self._is_multipath()
        is_dir = isinstance(self.path, Path) and self.path.is_dir()
        shape: TypeQuartetInt
        if self._loaded or custom_reader or multipath or is_dir:
            channels, si, sj, sk = self.data.shape
            shape = channels, si, sj, sk
        else:
            assert isinstance(self.path, (str, Path))
            shape = read_shape(self.path)
        return shape

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> tuple[str, str, str]:
        """Orientation codes."""
        return nib.orientations.aff2axcodes(self.affine)

    @property
    def orientation_str(self) -> str:
        """Orientation as a string."""
        return ''.join(self.orientation)

    @property
    def direction(self) -> TypeDirection3D:
        _, _, direction = get_sitk_metadata_from_ras_affine(
            self.affine,
            lps=False,
        )
        return direction  # type: ignore[return-value]

    @property
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        sx, sy, sz = spacing
        return float(sx), float(sy), float(sz)

    @property
    def origin(self) -> tuple[float, float, float]:
        """Center of first voxel in array, in mm."""
        ox, oy, oz = self.affine[:3, 3]
        return ox, oy, oz

    @property
    def itemsize(self):
        """Element size of the data type."""
        return self.data.element_size()

    @property
    def memory(self) -> float:
        """Number of bytes that the tensor takes up in RAM."""
        return np.prod(self.shape) * self.itemsize

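    # Back-of-envelope check (numbers taken from the class docstring example):
    # a float32 tensor of shape (1, 256, 256, 176) occupies
    # 1 * 256 * 256 * 176 * 4 bytes = 46_137_344 bytes, which humanize
    # reports as 44.0 MiB.
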
    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest indices."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = apply_affine(self.affine, ini)
        point_fin = apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        return len(self.data)

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis, as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -3 + index
            return index

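    # Sketch of the mapping above for an RAS+ oriented image (an assumption
    # chosen for illustration): 'S' appears at position 2 of the orientation
    # codes, so axis_name_to_index('Superior') returns -3 + 2 == -1, i.e. the
    # last dimension of both the 3D volume and the 4D (C, W, H, D) tensor.
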
    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the opposite axis label. For example, ``'L'`` -> ``'R'``.

        Args:
            axis: Axis label, such as ``'L'`` or ``'left'``.
        """
        labels = 'LRPAISTBDV'
        first = labels[::2]
        last = labels[1::2]
        flip_dict = dict(zip(first + last, last + first))
        axis = axis[0].upper()
        flipped_axis = flip_dict.get(axis)
        if flipped_axis is None:
            values = ', '.join(labels)
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return flipped_axis

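    # The pairing built from `labels` above maps L<->R, P<->A, I<->S, T<->B
    # and D<->V, so for example (a quick sketch, not from the docs):
    #
    #     >>> Image.flip_axis('L')
    #     'R'
    #     >>> Image.flip_axis('superior')
    #     'I'
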
    def get_spacing_string(self) -> str:
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = apply_affine(self.affine, first_index)
        last_point = apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()  # type: ignore[misc]
        return bounds_x, bounds_y, bounds_z  # type: ignore[return-value]

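    # Worked example (assuming an identity affine, chosen for illustration):
    # for a spatial shape of (2, 2, 2), voxel centers sit at indices 0 and 1,
    # so the occupied extent runs from -0.5 to 1.5 mm along each axis and
    # get_bounds() returns ([-0.5, 1.5], [-0.5, 1.5], [-0.5, 1.5]).
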
    def _parse_single_path(
        self,
        path: TypePath,
        *,
        verify: bool = True,
    ) -> Path:
        if isinstance(path, (torch.Tensor, np.ndarray)):
            class_name = self.__class__.__name__
            message = (
                'Expected type str or Path but found a tensor/array. Instead of'
                f' {class_name}(your_tensor),'
                f' use {class_name}(tensor=your_tensor).'
            )
            raise TypeError(message)
        try:
            path = Path(path).expanduser()
        except TypeError as err:
            message = (
                f'Expected type str or Path but found an object with type'
                f' {type(path)} instead'
            )
            raise TypeError(message) from err
        except RuntimeError as err:
            message = f'Conversion to path not possible for variable: {path}'
            raise RuntimeError(message) from err
        if not verify:
            return path

        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
        self,
        path: TypePath | Sequence[TypePath] | None,
        *,
        verify: bool = True,
    ) -> Path | list[Path] | None:
        if path is None:
            return None
        elif isinstance(path, dict):
            # https://github.com/TorchIO-project/torchio/pull/838
            raise TypeError('The path argument cannot be a dictionary')
        elif self._is_paths_sequence(path):
            return [self._parse_single_path(p, verify=verify) for p in path]  # type: ignore[union-attr]
        else:
            return self._parse_single_path(path, verify=verify)  # type: ignore[arg-type]

    def _parse_tensor(
        self,
        tensor: TypeData | None,
        none_ok: bool = True,
    ) -> torch.Tensor | None:
        if tensor is None:
            if none_ok:
                return None
            else:
                raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.as_tensor(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        ndim = tensor.ndim
        if ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning, stacklevel=2)
        return tensor

    @staticmethod
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: TypeData | None) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.numpy()
        if not isinstance(affine, np.ndarray):
            bad_type = type(affine)
            raise TypeError(f'Affine must be a NumPy array, not {bad_type}')
        if affine.shape != (4, 4):
            bad_shape = affine.shape
            raise ValueError(f'Affine shape must be (4, 4), not {bad_shape}')
        return affine.astype(np.float64)

    @staticmethod
    def _is_paths_sequence(path: TypePath | Sequence[TypePath] | None) -> bool:
        is_not_string = not isinstance(path, str)
        return is_not_string and is_iterable(path)

    def _is_multipath(self) -> bool:
        return self._is_paths_sequence(self.path)

    def _is_dir(self) -> bool:
        is_sequence = self._is_multipath()
        if is_sequence:
            return False
        elif self.path is None:
            return False
        else:
            assert isinstance(self.path, Path)
            return self.path.is_dir()

    def load(self) -> None:
        r"""Load the image from disk.

        The data is stored as a 4D tensor of size :math:`(C, W, H, D)` along
        with a 2D :math:`4 \times 4` affine matrix that converts voxel
        indices to world coordinates.
        """
        if self._loaded:
            return

        paths: list[Path]
        if self._is_multipath():
            paths = self.path  # type: ignore[assignment]
        else:
            paths = [self.path]  # type: ignore[list-item]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning, stacklevel=2)
            if not tensor.shape[1:] == new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match: found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def unload(self) -> None:
        """Unload the image from memory.

        Raises:
            RuntimeError: If the image has not been loaded yet or if no path
                is available.
        """
        if not self._loaded:
            message = 'Image cannot be unloaded as it has not been loaded yet'
            raise RuntimeError(message)
        if self.path is None:
            message = (
                'Cannot unload image as no path is available'
                ' from where the image could be loaded again'
            )
            raise RuntimeError(message)
        self[DATA] = None
        self[AFFINE] = None
        self._loaded = False

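    # Usage sketch (file name hypothetical): unload() drops the cached tensor
    # and affine to free memory; the path is kept so the image can be read
    # again later with load().
    #
    #     >>> image = ScalarImage('t1.nii.gz')  # doctest: +SKIP
    #     >>> _ = image.data   # loads and caches the tensor
    #     >>> image.unload()   # frees the cache
    #     >>> image.load()     # reads the file again
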
    def read_and_check(self, path: TypePath) -> TypeDataAffine:
        tensor, affine = self.reader(path)
        # Make sure the data type is compatible with PyTorch
        if self.reader is not read_image and isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
        tensor = self._parse_tensor_shape(tensor)  # type: ignore[assignment]
        tensor = self._parse_tensor(tensor)  # type: ignore[assignment]
        affine = self._parse_affine(affine)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(
                f'NaNs found in file "{path}"',
                RuntimeWarning,
                stacklevel=2,
            )
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool | None = None) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 93.8 KiB; dtype: torch.IntTensor)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 588.0 KiB; dtype: torch.FloatTensor)
        """
        tensor, affine = sitk_to_nib(sitk_image)
        return cls(tensor=tensor, affine=affine)

    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        .. note:: To use this method, Pillow needs to be installed:
            ``pip install Pillow``.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = 'Please install Pillow to use Image.as_pil(): pip install Pillow'
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        if len(tensor) not in (1, 3, 4):
            raise NotImplementedError(
                'Only 1, 3 or 4 channels are supported for conversion to Pillow image'
            )
        if len(tensor) == 1:
            tensor = torch.cat(3 * [tensor])
        if transpose:
            tensor = tensor.permute(3, 2, 1, 0)
        else:
            tensor = tensor.permute(3, 1, 2, 0)
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))

    def to_gif(
        self,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        rescale: bool = True,
        optimize: bool = True,
        reverse: bool = False,
    ) -> None:
        """Save an animated GIF of the image.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """
        from ..visualization import make_gif  # avoid circular import

        make_gif(
            self.data,
            axis,
            duration,
            output_path,
            loop=loop,
            rescale=rescale,
            optimize=optimize,
            reverse=reverse,
        )

    def to_ras(self) -> Image:
        if self.orientation_str != 'RAS':
            from ..transforms.preprocessing.spatial.to_canonical import ToCanonical

            return ToCanonical()(self)
        return self

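    # Quick note on to_ras(): the PIR+ image from the class docstring example
    # would be reoriented to RAS+ via ToCanonical, while an image that is
    # already RAS+ is returned as-is (no copy is made in that case).
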
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

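    # Worked example (identity affine assumed for illustration): a spatial
    # shape of (2, 2, 2) gives center_index = (0.5, 0.5, 0.5), so
    # get_center() returns (0.5, 0.5, 0.5) and get_center(lps=True) returns
    # (-0.5, -0.5, 0.5), since only the first two axes flip sign in LPS+.
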
    def set_check_nans(self, check_nans: bool) -> None:
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import

            plot_volume(self, **kwargs)

    def show(self, viewer_path: TypePath | None = None) -> None:
        """Open the image using external software.

        Args:
            viewer_path: Path to the application used to view the image. If
                ``None``, the value of the environment variable
                ``SITK_SHOW_COMMAND`` will be used. If this variable is also
                not set, TorchIO will try to guess the location of
                `ITK-SNAP <http://www.itksnap.org/pmwiki/pmwiki.php>`_ and
                `3D Slicer <https://www.slicer.org/>`_.

        Raises:
            RuntimeError: If the viewer is not found.
        """
        sitk_image = self.as_sitk()
        image_viewer = sitk.ImageViewer()
        # This is so that 3D Slicer creates segmentation nodes from label maps
        if self.__class__.__name__ == 'LabelMap':
            image_viewer.SetFileExtension('.seg.nrrd')
        if viewer_path is not None:
            image_viewer.SetApplication(str(viewer_path))
        try:
            image_viewer.Execute(sitk_image)
        except RuntimeError as e:
            viewer_path = guess_external_viewer()
            if viewer_path is None:
                message = (
                    'No external viewer has been found. Please set the'
                    ' environment variable SITK_SHOW_COMMAND to a viewer of'
                    ' your choice'
                )
                raise RuntimeError(message) from e
            image_viewer.SetApplication(str(viewer_path))
            image_viewer.Execute(sitk_image)

    def _crop_from_slices(
        self,
        slices: TypeSlice | tuple[TypeSlice, ...],
    ) -> Image:
        from ..transforms import Crop

        slices_tuple = to_tuple(slices)  # type: ignore[assignment]
        cropping: list[int] = []
        for dim, slice_ in enumerate(slices_tuple):
            if isinstance(slice_, slice):
                pass
            elif slice_ is Ellipsis:
                message = 'Ellipsis slicing is not supported yet'
                raise NotImplementedError(message)
            elif isinstance(slice_, int):
                slice_ = slice(slice_, slice_ + 1)  # type: ignore[assignment]
            else:
                message = f'Slice type not understood: "{type(slice_)}"'
                raise TypeError(message)
            shape_dim = self.spatial_shape[dim]
            assert isinstance(slice_, slice)
            start, stop, step = slice_.indices(shape_dim)
            if step != 1:
                message = (
                    'Slicing with steps different from 1 is not supported yet.'
                    ' Use the Crop transform instead'
                )
                raise ValueError(message)
            crop_ini = start
            crop_fin = shape_dim - stop
            cropping.extend([crop_ini, crop_fin])
        while dim < 2:
            cropping.extend([0, 0])
            dim += 1
        w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
        cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin  # making mypy happy
        return Crop(cropping_arg)(self)  # type: ignore[return-value]

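    # Slicing sketch for _crop_from_slices (hypothetical shapes): indexing an
    # image crops it spatially through the Crop transform, e.g.
    #
    #     >>> image = ScalarImage(tensor=torch.rand(1, 10, 10, 10))
    #     >>> image[2:5].spatial_shape
    #     (3, 10, 10)
    #     >>> image[2:5, 3:6, 0].spatial_shape
    #     (3, 3, 1)
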

class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)

    def hist(self, **kwargs) -> None:
        """Plot histogram."""
        from ..visualization import plot_histogram

        x = self.data.flatten().numpy()
        plot_histogram(x, **kwargs)

    def to_video(
        self,
        output_path: TypePath,
        frame_rate: float | None = 15,
        seconds: float | None = None,
        direction: str = 'I',
        verbosity: str = 'error',
    ) -> None:
        """Create a video showing all image slices along a specified direction.

        Args:
            output_path: Path to the output video file.
            frame_rate: Number of frames per second (FPS).
            seconds: Target duration of the full video.
            direction: Spatial direction along which the slices are extracted,
                e.g. ``'I'``.
            verbosity: Verbosity level passed to the video backend,
                e.g. ``'error'``.

        .. note:: Only ``frame_rate`` or ``seconds`` may (and must) be specified.
        """
        from ..visualization import make_video  # avoid circular import

        make_video(
            self.to_ras(),  # type: ignore[arg-type]
            output_path,
            frame_rate=frame_rate,
            seconds=seconds,
            direction=direction,
            verbosity=verbosity,
        )


class LabelMap(Image):
    """Image whose pixel values represent segmentation labels.

    A sequence of paths to 3D images can be passed to create a 4D image.
    This is useful to create a
    `tissue probability map (TPM) <https://andysbrainbook.readthedocs.io/en/latest/SPM/SPM_Short_Course/SPM_04_Preprocessing/04_SPM_Segmentation.html#tissue-probability-maps>`_,
    which contains the probability of each voxel belonging to a certain tissue type,
    or to create a label map with overlapping labels.

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> binary_tensor = torch.rand(1, 128, 128, 68) > 0.5
        >>> label_map = tio.LabelMap(tensor=binary_tensor)  # from a tensor
        >>> label_map = tio.LabelMap('t1_seg.nii.gz')  # from a file
        >>> # Create a 4D tissue probability map from different 3D images
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
        >>> tpm = tio.LabelMap(tissues)

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)

    def count_nonzero(self) -> int:
        """Get the number of voxels that are not 0."""
        return int(self.data.count_nonzero().item())

    def count_labels(self) -> dict[int, int]:
        """Get the number of voxels in each label."""
        values_list = self.data.flatten().tolist()
        counter = Counter(values_list)
        counts = {label: counter[label] for label in sorted(counter)}
        return counts
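    # Rough illustration (hypothetical data): for a binary label map every
    # nonzero voxel has value 1, so the keys are label values and the values
    # are voxel counts.
    #
    #     >>> label_map = LabelMap(tensor=torch.ones(1, 2, 2, 2, dtype=torch.uint8))
    #     >>> label_map.count_labels()
    #     {1: 8}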