Passed
Pull Request — main (#1350)
by Fernando
01:23
created

torchio.data.image.Image.plot()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 2
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import warnings
4
from collections import Counter
5
from collections.abc import Sequence
6
from pathlib import Path
7
from typing import Any
8
from typing import Callable
9
10
import humanize
11
import nibabel as nib
12
import numpy as np
13
import SimpleITK as sitk
14
import torch
15
from deprecated import deprecated
16
from nibabel.affines import apply_affine
17
18
from ..constants import AFFINE
19
from ..constants import DATA
20
from ..constants import INTENSITY
21
from ..constants import LABEL
22
from ..constants import PATH
23
from ..constants import STEM
24
from ..constants import TENSOR
25
from ..constants import TYPE
26
from ..types import TypeData
27
from ..types import TypeDataAffine
28
from ..types import TypeDirection3D
29
from ..types import TypePath
30
from ..types import TypeQuartetInt
31
from ..types import TypeSlice
32
from ..types import TypeTripletFloat
33
from ..types import TypeTripletInt
34
from ..utils import get_stem
35
from ..utils import guess_external_viewer
36
from ..utils import is_iterable
37
from ..utils import to_tuple
38
from .io import check_uint_to_int
39
from .io import ensure_4d
40
from .io import get_rotation_and_spacing_from_affine
41
from .io import get_sitk_metadata_from_ras_affine
42
from .io import nib_to_sitk
43
from .io import read_affine
44
from .io import read_image
45
from .io import read_shape
46
from .io import sitk_to_nib
47
from .io import write_image
48
49
# Keys managed by Image itself; they cannot be passed as extra keyword items
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# Pair of world coordinates delimiting the image along one axis
TypeBound = tuple[float, float]
# One such pair per spatial axis
TypeBounds = tuple[TypeBound, TypeBound, TypeBound]

# Warning text used by the deprecated ``Image.data`` property setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
57
58
59
class Image(dict):
60
    r"""TorchIO image.
61
62
    For information about medical image orientation, check out `NiBabel docs`_,
63
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
64
    `SimpleITK docs`_.
65
66
    Args:
67
        path: Path to a file or sequence of paths to files that can be read by
68
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
69
            DICOM files. If :attr:`tensor` is given, the data in
70
            :attr:`path` will not be read.
71
            If a sequence of paths is given, data
72
            will be concatenated on the channel dimension so spatial
73
            dimensions must match.
74
        type: Type of image, such as :attr:`torchio.INTENSITY` or
75
            :attr:`torchio.LABEL`. This will be used by the transforms to
76
            decide whether to apply an operation, or which interpolation to use
77
            when resampling. For example, `preprocessing`_ and `augmentation`_
78
            intensity transforms will only be applied to images with type
79
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
80
            all types, and nearest neighbor interpolation is always used to
81
            resample images with type :attr:`torchio.LABEL`.
82
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
83
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
84
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
85
            :class:`torch.Tensor` or NumPy array with dimensions
86
            :math:`(C, W, H, D)`.
87
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
88
            coordinates. If ``None``, an identity matrix will be used. See the
89
            `NiBabel docs on coordinates`_ for more information.
90
        check_nans: If ``True``, issues a warning if NaNs are found
91
            in the image. If ``False``, images will not be checked for the
92
            presence of NaNs.
93
        reader: Callable object that takes a path and returns a 4D tensor and a
94
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
95
            is saved in a custom format, such as ``.npy`` (see example below).
96
            If the affine matrix is ``None``, an identity matrix will be used.
97
        **kwargs: Items that will be added to the image dictionary, e.g.
98
            acquisition parameters or image ID.
99
        verify_path: If ``True``, the path will be checked to see if it exists. If
100
            ``False``, the path will not be verified. This is useful when it is
101
            expensive to check the path, e.g., when reading a large dataset from a
102
            mounted drive.
103
104
    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
105
    when needed.
106
107
    Example:
108
        >>> import torchio as tio
109
        >>> import numpy as np
110
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
111
        >>> image  # not loaded yet
112
        ScalarImage(path: t1.nii.gz; type: intensity)
113
        >>> times_two = 2 * image.data  # data is loaded and cached here
114
        >>> image
115
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
116
        >>> image.save('doubled_image.nii.gz')
117
        >>> def numpy_reader(path):
118
        ...     data = np.load(path).astype(np.float32)
119
        ...     affine = np.eye(4)
120
        ...     return data, affine
121
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)
122
123
    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
124
    .. _preprocessing: https://torchio-project.github.io/torchio/transforms/preprocessing.html#intensity
125
    .. _augmentation: https://torchio-project.github.io/torchio/transforms/augmentation.html#intensity
126
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
127
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
128
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
129
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
130
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
131
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
132
    """
133
134
    def __init__(
        self,
        path: TypePath | Sequence[TypePath] | None = None,
        type: str | None = None,  # noqa: A002
        tensor: TypeData | None = None,
        affine: TypeData | None = None,
        check_nans: bool = False,  # removed by ITK by default
        reader: Callable[[TypePath], TypeDataAffine] = read_image,
        verify_path: bool = True,
        **kwargs: dict[str, Any],
    ):
        """Create an image from a path and/or an in-memory tensor.

        See the class docstring for a description of each argument.

        Raises:
            ValueError: If neither ``path`` nor ``tensor`` is given, or if a
                reserved key is passed in ``kwargs``.
        """
        # These must be set before parsing the tensor: _parse_tensor()
        # reads self.check_nans
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                FutureWarning,
                stacklevel=2,
            )
            type = INTENSITY  # noqa: A001

        if tensor is None and path is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        parsed_tensor = self._parse_tensor(tensor)
        parsed_affine = self._parse_affine(affine)
        if parsed_tensor is not None:
            # Data provided in memory: the image is loaded from the start
            self.set_data(parsed_tensor)
            self.affine = parsed_affine
            self._loaded = True

        reserved = next((key for key in PROTECTED_KEYS if key in kwargs), None)
        if reserved is not None:
            raise ValueError(f'Key "{reserved}" is reserved. Use a different one')

        if 'channels_last' in kwargs:
            warnings.warn(
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/TorchIO-project/torchio/pull/685 and will be'
                ' removed in the future',
                FutureWarning,
                stacklevel=2,
            )

        # Remaining keyword arguments become regular dictionary items
        super().__init__(**kwargs)
        self.path = self._parse_path(path, verify=verify_path)

        has_path = self.path is not None
        self[PATH] = str(self.path) if has_path else ''
        self[STEM] = get_stem(self.path) if has_path else ''
        self[TYPE] = type
186
187
    def __repr__(self):
        """Summarize geometry, plus memory usage if loaded or path otherwise."""
        parts = [
            f'shape: {self.shape}',
            f'spacing: {self.get_spacing_string()}',
            f'orientation: {self.orientation_str}+',
        ]
        if not self._loaded:
            parts.append(f'path: "{self.path}"')
        else:
            natural = humanize.naturalsize(self.memory, binary=True)
            parts.extend([f'dtype: {self.data.type()}', f'memory: {natural}'])
        joined = '; '.join(parts)
        return f'{self.__class__.__name__}({joined})'
206
207
    def __getitem__(self, item):
        """Crop with slices/ints/tuples, or look up a dictionary key.

        The data and affine entries are loaded lazily the first time they are
        requested if they are not stored yet.
        """
        if isinstance(item, (slice, int, tuple)):
            return self._crop_from_slices(item)
        must_load = item in (DATA, AFFINE) and item not in self
        if must_load:
            self.load()
        return super().__getitem__(item)
215
216
    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """Return the image data as a NumPy array (NumPy array protocol).

        Accepting ``dtype`` and ``copy`` makes calls such as
        ``np.asarray(image, dtype=np.float32)`` work, and follows the
        ``__array__`` signature expected by NumPy 2.

        Args:
            dtype: Optional NumPy dtype to cast the data to.
            copy: If ``True``, always return a new array. If ``False``, never
                copy, raising :class:`ValueError` if a cast would be needed.
                If ``None``, copy only when a cast is required.
        """
        array = self.data.numpy()  # shares memory with the tensor
        if dtype is not None and array.dtype != np.dtype(dtype):
            if copy is False:
                message = (
                    'Unable to avoid a copy while casting the image data to'
                    f' {np.dtype(dtype)}'
                )
                raise ValueError(message)
            return array.astype(dtype)  # astype allocates a new array here
        if copy:
            return array.copy()
        return array
218
219
    def __copy__(self):
        """Return a shallow copy of the image.

        If the image is loaded, the new image is created from the same tensor
        and affine. Non-reserved dictionary items are copied by reference.
        The path is not checked on disk again: it was already validated (or
        deliberately not validated) when this image was created. The data may
        also already be in memory, so the file is not required to still exist.
        """
        kwargs = {
            TYPE: self.type,
            PATH: self.path,
        }
        if self._loaded:
            kwargs[TENSOR] = self.data
            kwargs[AFFINE] = self.affine
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are shared with the original
        new_image_class = type(self)
        new_image = new_image_class(
            check_nans=self.check_nans,
            reader=self.reader,
            verify_path=False,
            **kwargs,
        )
        return new_image
238
239
    @property
    def data(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.tensor`).

        Reading this loads the image from disk if it is not loaded yet.
        """
        return self[DATA]
243
244
    @data.setter  # type: ignore[misc]
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        """Deprecated property setter; forwards to :meth:`set_data`."""
        self.set_data(tensor)
248
249
    def set_data(self, tensor: TypeData):
        """Validate a 4D tensor and store it under the :attr:`data` key.

        The image is marked as loaded afterwards.

        Args:
            tensor: 4D tensor or NumPy array with dimensions
                :math:`(C, W, H, D)`.
        """
        parsed = self._parse_tensor(tensor, none_ok=False)
        self[DATA] = parsed
        self._loaded = True
257
258
    @property
    def tensor(self) -> torch.Tensor:
        """Alias of :attr:`data` (same as :class:`Image.data`)."""
        data = self.data
        return data
262
263
    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        # Directories (probably DICOM) and sequences of paths (used to create
        # a 4D image) require loading the data. So do custom readers, as
        # SimpleITK probably won't be able to read their metadata.
        must_load = (
            self._loaded
            or self._is_dir()
            or self._is_multipath()
            or self.reader is not read_image
        )
        if must_load:
            return self[AFFINE]
        assert self.path is not None
        assert isinstance(self.path, (str, Path))
        return read_affine(self.path)
278
279
    @affine.setter
    def affine(self, matrix):
        """Validate ``matrix`` and store it under the affine key."""
        parsed = self._parse_affine(matrix)
        self[AFFINE] = parsed
282
283
    @property
    def type(self) -> str:  # noqa: A003
        """Image type, such as :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        image_type = self[TYPE]
        return image_type
286
287
    @property
    def shape(self) -> TypeQuartetInt:
        """Tensor shape as :math:`(C, W, H, D)`.

        If the data is already in memory, or must be loaded anyway (custom
        reader, several paths, or a directory), the shape is taken from the
        tensor. Otherwise it is obtained with ``read_shape`` without loading
        the data into this image.
        """
        # The conditions are evaluated lazily so the filesystem is only
        # queried (via _is_dir) when the cheaper checks are all False. This
        # also avoids the previously possible unbound ``is_dir`` variable.
        needs_data = (
            self._loaded
            or self.reader is not read_image
            or self._is_multipath()
            or self._is_dir()
        )
        shape: TypeQuartetInt
        if needs_data:
            channels, si, sj, sk = self.data.shape
            shape = channels, si, sj, sk
        else:
            assert isinstance(self.path, (str, Path))
            shape = read_shape(self.path)
        return shape
302
303
    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Spatial part of the tensor shape, :math:`(W, H, D)`."""
        full_shape = self.shape
        return full_shape[1:]
307
308
    def check_is_2d(self) -> None:
        """Raise :class:`RuntimeError` unless the image is 2D."""
        if self.is_2d():
            return
        message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
        raise RuntimeError(message)
312
313
    @property
    def height(self) -> int:
        """Image height (second spatial dimension); only valid for 2D images."""
        self.check_is_2d()
        _, height, _ = self.spatial_shape
        return height
318
319
    @property
    def width(self) -> int:
        """Image width (first spatial dimension); only valid for 2D images."""
        self.check_is_2d()
        width, _, _ = self.spatial_shape
        return width
324
325
    @property
    def orientation(self) -> tuple[str, str, str]:
        """Axis orientation codes computed from the affine by NiBabel."""
        codes = nib.orientations.aff2axcodes(self.affine)
        return codes
329
330
    @property
    def orientation_str(self) -> str:
        """Orientation codes concatenated into one string, e.g. ``'RAS'``."""
        return ''.join(code for code in self.orientation)
334
335
    @property
    def direction(self) -> TypeDirection3D:
        """Direction component of the SimpleITK metadata, in RAS (not LPS)."""
        _, _, direction = get_sitk_metadata_from_ras_affine(self.affine, lps=False)
        return direction  # type: ignore[return-value]
342
343
    @property
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm, as plain Python floats."""
        _, voxel_sizes = get_rotation_and_spacing_from_affine(self.affine)
        sx, sy, sz = (float(size) for size in voxel_sizes)
        return sx, sy, sz
349
350
    @property
    def origin(self) -> tuple[float, float, float]:
        """Center of first voxel in array, in mm.

        Returned as plain Python floats (like :attr:`spacing`) rather than
        NumPy scalars, matching the annotated return type.
        """
        ox, oy, oz = self.affine[:3, 3]
        return float(ox), float(oy), float(oz)
355
356
    @property
    def itemsize(self):
        """Size in bytes of one element of the tensor data type."""
        tensor = self.data
        return tensor.element_size()
360
361
    @property
    def memory(self) -> float:
        """Number of bytes that the tensor takes in RAM.

        Computed as number of elements times element size. Reading
        :attr:`itemsize` accesses :attr:`data`, so this loads the image.
        """
        num_elements = np.prod(self.shape)
        return num_elements * self.itemsize
365
366
    @property
    def bounds(self) -> np.ndarray:
        """World positions of the centers of the first and last voxels.

        Returns a :math:`2 \\times 3` array: the first row is the voxel at
        index ``(0, 0, 0)`` and the second the voxel at the largest index.
        """
        last_index = np.array(self.spatial_shape) - 1
        corners = [
            apply_affine(self.affine, index)
            for index in ((0, 0, 0), last_index)
        ]
        return np.array(corners)
374
375
    @property
    def num_channels(self) -> int:
        """Size of the first (channel) dimension of the 4D tensor."""
        channels = len(self.data)
        return channels
379
380
    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to a (negative) axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index, so it is valid whether or not the channel
            dimension is included.

        Raises:
            ValueError: If ``axis`` is not a string or cannot be matched to an
                axis of the image orientation.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # 'Top'/'Bottom' name the vertical 2D axis because 'Height' and
        # 'Horizontal' share their first letter, which would be ambiguous
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        letter = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if letter in 'TB':  # Top, Bottom
            return -2

        orientation = self.orientation
        if letter in orientation:
            position = orientation.index(letter)
        else:
            position = orientation.index(self.flip_axis(letter))
        return position - 3
416
417
    @staticmethod
418
    def flip_axis(axis: str) -> str:
419
        """Return the opposite axis label. For example, ``'L'`` -> ``'R'``.
420
421
        Args:
422
            axis: Axis label, such as ``'L'`` or ``'left'``.
423
        """
424
        labels = 'LRPAISTBDV'
425
        first = labels[::2]
426
        last = labels[1::2]
427
        flip_dict = dict(zip(first + last, last + first))
428
        axis = axis[0].upper()
429
        flipped_axis = flip_dict.get(axis)
430
        if flipped_axis is None:
431
            values = ', '.join(labels)
432
            message = f'Axis not understood. Please use one of: {values}'
433
            raise ValueError(message)
434
        return flipped_axis
435
436
    def get_spacing_string(self) -> str:
        """Format the spacing as ``'(x.xx, y.yy, z.zz)'``."""
        formatted = ', '.join(f'{n:.2f}' for n in self.spacing)
        return f'({formatted})'
440
441
    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image.

        The outer voxel faces are used (indices shifted by half a voxel),
        not the voxel centers.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        corners = np.array(
            [apply_affine(self.affine, index) for index in (first_index, last_index)]
        )
        bounds_x, bounds_y, bounds_z = corners.T.tolist()  # type: ignore[misc]
        return bounds_x, bounds_y, bounds_z  # type: ignore[return-value]
450
451
    def _parse_single_path(
        self,
        path: TypePath,
        *,
        verify: bool = True,
    ) -> Path:
        """Convert ``path`` to an expanded :class:`~pathlib.Path`.

        Args:
            path: String or path-like object.
            verify: If ``True``, require the path to be an existing file or
                directory.

        Raises:
            TypeError: If ``path`` is a tensor/array or cannot be converted.
            RuntimeError: If the conversion raises :class:`RuntimeError`.
            FileNotFoundError: If ``verify`` is ``True`` and nothing exists at
                the path.
        """
        if isinstance(path, (torch.Tensor, np.ndarray)):
            name = self.__class__.__name__
            raise TypeError(
                'Expected type str or Path but found a tensor/array. Instead of'
                f' {name}(your_tensor),'
                f' use {name}(tensor=your_tensor).'
            )
        try:
            parsed = Path(path).expanduser()
        except TypeError as err:
            raise TypeError(
                f'Expected type str or Path but found an object with type'
                f' {type(path)} instead'
            ) from err
        except RuntimeError as err:
            raise RuntimeError(
                f'Conversion to path not possible for variable: {path}'
            ) from err
        # Only touch the filesystem when asked to; a directory may hold DICOM
        if verify and not (parsed.is_file() or parsed.is_dir()):
            raise FileNotFoundError(f'File not found: "{parsed}"')
        return parsed
482
483
    def _parse_path(
        self,
        path: TypePath | Sequence[TypePath] | None,
        *,
        verify: bool = True,
    ) -> Path | list[Path] | None:
        """Parse ``None``, a single path, or a sequence of paths.

        Raises:
            TypeError: If ``path`` is a dictionary.
        """
        if path is None:
            return None
        if isinstance(path, dict):
            # https://github.com/TorchIO-project/torchio/pull/838
            raise TypeError('The path argument cannot be a dictionary')
        if not self._is_paths_sequence(path):
            return self._parse_single_path(path, verify=verify)  # type: ignore[arg-type]
        return [self._parse_single_path(p, verify=verify) for p in path]  # type: ignore[union-attr]
498
499
    def _parse_tensor(
        self,
        tensor: TypeData | None,
        none_ok: bool = True,
    ) -> torch.Tensor | None:
        """Validate input data and return it as a 4D :class:`torch.Tensor`.

        NumPy arrays are converted with ``check_uint_to_int`` first, and
        boolean tensors are cast to ``torch.uint8``.

        Args:
            tensor: 4D tensor or NumPy array, or ``None``.
            none_ok: If ``True``, ``None`` is returned as is; otherwise it is
                an error.

        Raises:
            RuntimeError: If ``tensor`` is ``None`` and ``none_ok`` is False.
            TypeError: If ``tensor`` is neither a tensor nor an array.
            ValueError: If the data is not 4D.
        """
        if tensor is None:
            if not none_ok:
                raise RuntimeError('Input tensor cannot be None')
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.as_tensor(check_uint_to_int(tensor))
        elif not isinstance(tensor, torch.Tensor):
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {tensor.ndim}D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning, stacklevel=2)
        return tensor
526
527
    @staticmethod
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        """Return ``tensor`` passed through ``ensure_4d``."""
        four_dimensional = ensure_4d(tensor)
        return four_dimensional
530
531
    @staticmethod
532
    def _parse_affine(affine: TypeData | None) -> np.ndarray:
533
        if affine is None:
534
            return np.eye(4)
535
        if isinstance(affine, torch.Tensor):
536
            affine = affine.numpy()
537
        if not isinstance(affine, np.ndarray):
538
            bad_type = type(affine)
539
            raise TypeError(f'Affine must be a NumPy array, not {bad_type}')
540
        if affine.shape != (4, 4):
541
            bad_shape = affine.shape
542
            raise ValueError(f'Affine shape must be (4, 4), not {bad_shape}')
543
        return affine.astype(np.float64)
544
545
    @staticmethod
    def _is_paths_sequence(path: TypePath | Sequence[TypePath] | None) -> bool:
        """Return ``True`` for iterables of paths (strings are excluded)."""
        return not isinstance(path, str) and is_iterable(path)
549
550
    def _is_multipath(self) -> bool:
        """Return ``True`` if :attr:`path` holds several paths."""
        path = self.path
        return self._is_paths_sequence(path)
552
553
    def _is_dir(self) -> bool:
        """Return ``True`` if :attr:`path` is a single existing directory."""
        if self._is_multipath() or self.path is None:
            return False
        assert isinstance(self.path, Path)
        return self.path.is_dir()
562
563
    def load(self) -> None:
        r"""Load the image from disk.

        Does nothing if the image is already loaded. If :attr:`path` is a
        sequence of paths, every file is read and the tensors are
        concatenated along the channel dimension. The affine of the first
        file is kept; a warning is issued for files with a different affine.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return

        paths: list[Path]
        if self._is_multipath():
            paths = self.path  # type: ignore[assignment]
        else:
            paths = [self.path]  # type: ignore[list-item]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning, stacklevel=2)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                # Previously the exception was created but never raised
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True
603
604
    def unload(self) -> None:
        """Unload the image from memory.

        The data and affine entries are removed, so the next access to
        :attr:`data` (through ``__getitem__``) reloads them from
        :attr:`path`.

        Raises:
            RuntimeError: If the image has not been loaded yet or if no path
                is available.
        """
        if not self._loaded:
            message = 'Image cannot be unloaded as it has not been loaded yet'
            raise RuntimeError(message)
        if self.path is None:
            message = (
                'Cannot unload image as no path is available'
                ' from where the image could be loaded again'
            )
            raise RuntimeError(message)
        # Remove the keys instead of storing None: __getitem__ only triggers
        # load() when the key is missing, so None would be returned forever
        self.pop(DATA, None)
        self.pop(AFFINE, None)
        self._loaded = False
623
624
    def read_and_check(self, path: TypePath) -> TypeDataAffine:
        """Read one file with :attr:`reader` and validate its tensor and affine.

        Returns:
            Tuple of a 4D tensor and a float64 :math:`4 \\times 4` affine.
        """
        tensor, affine = self.reader(path)
        # Make sure the data type is compatible with PyTorch
        custom_reader = self.reader is not read_image
        if custom_reader and isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
        tensor = self._parse_tensor(self._parse_tensor_shape(tensor))  # type: ignore[assignment]
        affine = self._parse_affine(affine)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(
                f'NaNs found in file "{path}"',
                RuntimeWarning,
                stacklevel=2,
            )
        return tensor, affine
639
640
    def save(self, path: TypePath, squeeze: bool | None = None) -> None:
        """Write the image data and affine to ``path``.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        data, affine = self.data, self.affine
        write_image(data, affine, path, squeeze=squeeze)
655
656
    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        *_, depth = self.shape
        return depth == 1
658
659
    def numpy(self) -> np.ndarray:
        """Return the image data as a NumPy array (via ``__array__``)."""
        array = np.asarray(self)
        return array
662
663
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Convert the image to a :class:`sitk.Image`.

        Extra keyword arguments are forwarded to ``nib_to_sitk``.
        """
        data, affine = self.data, self.affine
        return nib_to_sitk(data, affine, **kwargs)
666
667
    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.IntTensor; memory: 93.8 KiB)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 588.0 KiB)
        """
        tensor, affine = sitk_to_nib(sitk_image)
        image = cls(tensor=tensor, affine=affine)
        return image
683
684
    def as_pil(self, transpose=True):
        """Convert a 2D image to a :class:`PIL.Image`.

        Single-channel images are repeated to three channels. If
        ``transpose`` is ``True``, the two spatial axes are swapped.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        .. note:: To use this method, Pillow needs to be installed:
            ``pip install Pillow``.

        Raises:
            RuntimeError: If Pillow is missing or the image is not 2D.
            NotImplementedError: If the number of channels is not 1, 3 or 4.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = 'Please install Pillow to use Image.as_pil(): pip install Pillow'
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        num_channels = len(tensor)
        if num_channels not in (1, 3, 4):
            raise NotImplementedError(
                'Only 1, 3 or 4 channels are supported for conversion to Pillow image'
            )
        if num_channels == 1:
            tensor = torch.cat([tensor] * 3)
        # Move the singleton depth axis first so that [0] drops it, leaving
        # the channels last as Pillow expects
        order = (3, 2, 1, 0) if transpose else (3, 1, 2, 0)
        array = tensor.permute(*order).clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))
712
713
    def to_gif(
        self,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        rescale: bool = True,
        optimize: bool = True,
        reverse: bool = False,
    ) -> None:
        """Write an animated GIF sweeping through slices of the image.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """
        from ..visualization import make_gif  # avoid circular import

        options = {
            'loop': loop,
            'rescale': rescale,
            'optimize': optimize,
            'reverse': reverse,
        }
        make_gif(self.data, axis, duration, output_path, **options)
750
751
    def to_ras(self) -> Image:
        """Return the image in RAS orientation (``self`` if it already is)."""
        if self.orientation_str == 'RAS':
            return self
        from ..transforms.preprocessing.spatial.to_canonical import ToCanonical

        return ToCanonical()(self)
757
758
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get the world position of the image center in RAS+ or LPS+.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        center_index = (np.array(self.spatial_shape) - 1) / 2
        r, a, s = apply_affine(self.affine, center_index)
        return (-r, -a, s) if lps else (r, a, s)
773
774
    def set_check_nans(self, check_nans: bool) -> None:
775
        self.check_nans = check_nans
776
777
    def plot(self, **kwargs) -> None:
778
        """Plot image."""
779
        if self.is_2d():
780
            self.as_pil().show()
781
        else:
782
            from ..visualization import plot_volume  # avoid circular import
783
784
            plot_volume(self, **kwargs)
785
786
    def show(self, viewer_path: TypePath | None = None) -> None:
787
        """Open the image using external software.
788
789
        Args:
790
            viewer_path: Path to the application used to view the image. If
791
                ``None``, the value of the environment variable
792
                ``SITK_SHOW_COMMAND`` will be used. If this variable is also
793
                not set, TorchIO will try to guess the location of
794
                `ITK-SNAP <http://www.itksnap.org/pmwiki/pmwiki.php>`_ and
795
                `3D Slicer <https://www.slicer.org/>`_.
796
797
        Raises:
798
            RuntimeError: If the viewer is not found.
799
        """
800
        sitk_image = self.as_sitk()
801
        image_viewer = sitk.ImageViewer()
802
        # This is so that 3D Slicer creates segmentation nodes from label maps
803
        if self.__class__.__name__ == 'LabelMap':
804
            image_viewer.SetFileExtension('.seg.nrrd')
805
        if viewer_path is not None:
806
            image_viewer.SetApplication(str(viewer_path))
807
        try:
808
            image_viewer.Execute(sitk_image)
809
        except RuntimeError as e:
810
            viewer_path = guess_external_viewer()
811
            if viewer_path is None:
812
                message = (
813
                    'No external viewer has been found. Please set the'
814
                    ' environment variable SITK_SHOW_COMMAND to a viewer of'
815
                    ' your choice'
816
                )
817
                raise RuntimeError(message) from e
818
            image_viewer.SetApplication(str(viewer_path))
819
            image_viewer.Execute(sitk_image)
820
821
    def _crop_from_slices(
822
        self,
823
        slices: TypeSlice | tuple[TypeSlice, ...],
824
    ) -> Image:
825
        from ..transforms import Crop
826
827
        slices_tuple = to_tuple(slices)  # type: ignore[assignment]
828
        cropping: list[int] = []
829
        for dim, slice_ in enumerate(slices_tuple):
830
            if isinstance(slice_, slice):
831
                pass
832
            elif slice_ is Ellipsis:
833
                message = 'Ellipsis slicing is not supported yet'
834
                raise NotImplementedError(message)
835
            elif isinstance(slice_, int):
836
                slice_ = slice(slice_, slice_ + 1)  # type: ignore[assignment]
837
            else:
838
                message = f'Slice type not understood: "{type(slice_)}"'
839
                raise TypeError(message)
840
            shape_dim = self.spatial_shape[dim]
841
            assert isinstance(slice_, slice)
842
            start, stop, step = slice_.indices(shape_dim)
843
            if step != 1:
844
                message = (
845
                    'Slicing with steps different from 1 is not supported yet.'
846
                    ' Use the Crop transform instead'
847
                )
848
                raise ValueError(message)
849
            crop_ini = start
850
            crop_fin = shape_dim - stop
851
            cropping.extend([crop_ini, crop_fin])
852
        while dim < 2:
0 ignored issues
show
introduced by
The variable dim does not seem to be defined in case the for loop on line 829 is not entered. Are you sure this can never be the case?
Loading history...
853
            cropping.extend([0, 0])
854
            dim += 1
855
        w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
856
        cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin  # making mypy happy
857
        return Crop(cropping_arg)(self)  # type: ignore[return-value]
858
859
860
class ScalarImage(Image):
861
    """Image whose pixel values represent scalars.
862
863
    Example:
864
        >>> import torch
865
        >>> import torchio as tio
866
        >>> # Loading from a file
867
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
868
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
869
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
870
        >>> data, affine = image.data, image.affine
871
        >>> affine.shape
872
        (4, 4)
873
        >>> image.data is image[tio.DATA]
874
        True
875
        >>> image.data is image.tensor
876
        True
877
        >>> type(image.data)
878
        torch.Tensor
879
880
    See :class:`~torchio.Image` for more information.
881
    """
882
883
    def __init__(self, *args, **kwargs):
884
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
885
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
886
        kwargs.update({'type': INTENSITY})
887
        super().__init__(*args, **kwargs)
888
889
    def hist(self, **kwargs) -> None:
890
        """Plot histogram."""
891
        from ..visualization import plot_histogram
892
893
        x = self.data.flatten().numpy()
894
        plot_histogram(x, **kwargs)
895
896
    def to_video(
897
        self,
898
        output_path: TypePath,
899
        frame_rate: float | None = 15,
900
        seconds: float | None = None,
901
        direction: str = 'I',
902
        verbosity: str = 'error',
903
    ) -> None:
904
        """Create a video showing all image slices along a specified direction.
905
906
        Args:
907
            output_path: Path to the output video file.
908
            frame_rate: Number of frames per second (FPS).
909
            seconds: Target duration of the full video.
910
            direction:
911
            verbosity:
912
913
        .. note:: Only ``frame_rate`` or ``seconds`` may (and must) be specified.
914
        """
915
        from ..visualization import make_video  # avoid circular import
916
917
        make_video(
918
            self.to_ras(),  # type: ignore[arg-type]
919
            output_path,
920
            frame_rate=frame_rate,
921
            seconds=seconds,
922
            direction=direction,
923
            verbosity=verbosity,
924
        )
925
926
927
class LabelMap(Image):
928
    """Image whose pixel values represent segmentation labels.
929
930
    A sequence of paths to 3D images can be passed to create a 4D image.
931
    This is useful to create a
932
    `tissue probability map (TPM) <https://andysbrainbook.readthedocs.io/en/latest/SPM/SPM_Short_Course/SPM_04_Preprocessing/04_SPM_Segmentation.html#tissue-probability-maps>`,
933
    which contains the probability of each voxel belonging to a certain tissue type,
934
    or to create a label map with overlapping labels.
935
936
    Intensity transforms are not applied to these images.
937
938
    Nearest neighbor interpolation is always used to resample label maps,
939
    independently of the specified interpolation type in the transform
940
    instantiation.
941
942
    Example:
943
        >>> import torch
944
        >>> import torchio as tio
945
        >>> binary_tensor = torch.rand(1, 128, 128, 68) > 0.5
946
        >>> label_map = tio.LabelMap(tensor=binary_tensor)  # from a tensor
947
        >>> label_map = tio.LabelMap('t1_seg.nii.gz')  # from a file
948
        >>> # Create a 4D tissue probability map from different 3D images
949
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
950
        >>> tpm = tio.LabelMap(tissues)
951
952
    See :class:`~torchio.Image` for more information.
953
    """
954
955
    def __init__(self, *args, **kwargs):
956
        if 'type' in kwargs and kwargs['type'] != LABEL:
957
            raise ValueError('Type of LabelMap is always torchio.LABEL')
958
        kwargs.update({'type': LABEL})
959
        super().__init__(*args, **kwargs)
960
961
    def count_nonzero(self) -> int:
962
        """Get the number of voxels that are not 0."""
963
        return int(self.data.count_nonzero().item())
964
965
    def count_labels(self) -> dict[int, int]:
966
        """Get the number of voxels in each label."""
967
        values_list = self.data.flatten().tolist()
968
        counter = Counter(values_list)
969
        counts = {label: counter[label] for label in sorted(counter)}
970
        return counts
971