torchio.data.image.Image.memory()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from collections.abc import Iterable
4
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List, Callable
5
6
import torch
7
import humanize
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
from deprecated import deprecated
12
13
from ..utils import get_stem
14
from ..typing import (
15
    TypeData,
16
    TypePath,
17
    TypeTripletInt,
18
    TypeTripletFloat,
19
    TypeDirection3D,
20
)
21
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
22
from .io import (
23
    ensure_4d,
24
    read_image,
25
    write_image,
26
    nib_to_sitk,
27
    sitk_to_nib,
28
    check_uint_to_int,
29
    get_rotation_and_spacing_from_affine,
30
    get_sitk_metadata_from_ras_affine,
31
    read_shape,
32
    read_affine,
33
)
34
35
36
# Keys that Image manages itself; passing any of them as an extra keyword
# argument to Image() raises a ValueError.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# A pair of world coordinates along one axis, and one such pair per spatial
# axis (as returned by Image.get_bounds).
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Text attached to the deprecated ``Image.data`` property setter.
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
44
45
46
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
            coordinates. If ``None``, an identity matrix will be used. See the
            `NiBabel docs on coordinates`_ for more information.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        reader: Callable object that takes a path and returns a 4D tensor and a
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
            is saved in a custom format, such as ``.npy`` (see example below).
            If the affine matrix is ``None``, an identity matrix will be used.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> import numpy as np
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; dtype: torch.FloatTensor; memory: 44.0 MiB)
        >>> image.save('doubled_image.nii.gz')
        >>> numpy_reader = lambda path: (np.load(path), np.eye(4))
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            reader: Callable = read_image,
            **kwargs: Any,
            ):
        # Must be set before _parse_tensor(), which reads self.check_nans
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead',
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # In-memory data: nothing will ever be read from disk
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        if 'channels_last' in kwargs:
            message = (
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/fepegar/torchio/pull/685 and will be'
                ' removed in the future'
            )
            warnings.warn(message, DeprecationWarning)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # The shape is computed first: for some images (custom reader, list of
        # paths, directory) this loads the data, which changes self._loaded
        properties = [
            f'shape: {self.shape}',
            f'spacing: {self.get_spacing_string()}',
            f'orientation: {"".join(self.orientation)}+',
        ]
        if self._loaded:
            properties.append(f'dtype: {self.data.type()}')
            properties.append(f'memory: {humanize.naturalsize(self.memory, binary=True)}')
        else:
            properties.append(f'path: "{self.path}"')

        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        """Dictionary access that lazily loads data and affine on first use."""
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        """Return a new image of the same class sharing the same tensor.

        Reading settings (``check_nans`` and ``reader``) are preserved. Extra,
        non-protected dictionary items are passed by reference (not copied).
        """
        kwargs = {
            'tensor': self.data,
            'affine': self.affine,
            'type': self.type,
            'path': self.path,
            'check_nans': self.check_nans,
            'reader': self.reader,
        }
        for key, value in self.items():
            if key not in PROTECTED_KEYS:
                kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter  # type: ignore
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    def _requires_full_load(self) -> bool:
        """Return whether metadata must be obtained by loading the image.

        The header-only helpers ``read_shape`` and ``read_affine`` are only
        used for a single, non-directory path read with the default reader.
        In every other case (already loaded, custom reader, list of paths,
        directory such as DICOM) the metadata comes from the loaded image.
        """
        if self._loaded:
            return True
        custom_reader = self.reader is not read_image
        multipath = not isinstance(self.path, (str, Path))
        return custom_reader or multipath or self.path.is_dir()

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        # If path is a dir (probably DICOM), just load the data
        # Same if it's a list of paths (used to create a 4D image) or if a
        # custom reader is used, as read_affine might not understand the file
        if self._requires_full_load():
            affine = self[AFFINE]
        else:
            affine = read_affine(self.path)
        return affine

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        if self._requires_full_load():
            shape = tuple(self.data.shape)
        else:
            shape = read_shape(self.path)
        return shape

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the last spatial size is not 1."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def direction(self) -> TypeDirection3D:
        _, _, direction = get_sitk_metadata_from_ras_affine(
            self.affine, lps=False)
        return direction

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def origin(self) -> Tuple[float, float, float]:
        """Center of first voxel in array, in mm."""
        return tuple(self.affine[:3, 3])

    @property
    def itemsize(self):
        """Element size of the data type."""
        return self.data.element_size()

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM.

        Accessing this property loads the data, as the element size is read
        from the tensor.
        """
        return np.prod(self.shape) * self.itemsize

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        return len(self.data)

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), so that it is valid
            both for the 4D tensor and for the spatial shape.

        Raises:
            ValueError: If ``axis`` is not a non-empty string or its first
                letter is not a known axis code.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the code of the opposite direction along the same axis.

        Raises:
            ValueError: If ``axis`` is not one of ``L, R, P, A, I, S, T, B``.
        """
        opposites = {
            'R': 'L', 'L': 'R',
            'A': 'P', 'P': 'A',
            'I': 'S', 'S': 'I',
            'T': 'B', 'B': 'T',  # top / bottom
        }
        flipped_axis = opposites.get(axis)
        if flipped_axis is None:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return flipped_axis

    def get_spacing_string(self) -> str:
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get world coordinates of the outer edges of the first/last voxels.

        For each spatial axis, the pair is ``(first, last)`` in voxel order,
        i.e. the world coordinate of the outer edge of the first voxel and of
        the last voxel. This equals ``(min, max)`` only when the axis maps to
        increasing world coordinates.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an existing file or directory :class:`Path`."""
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Return ``None``, a single :class:`Path` or a list of them."""
        if path is None:
            return None
        if isinstance(path, Iterable) and not isinstance(path, str):
            return [self._parse_single_path(p) for p in path]
        return self._parse_single_path(path)

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate a 4D array/tensor and convert it to a :class:`torch.Tensor`.

        NumPy arrays go through ``check_uint_to_int`` before conversion and
        boolean tensors are cast to ``torch.uint8``.

        Raises:
            RuntimeError: If ``tensor`` is ``None`` and ``none_ok`` is False.
            TypeError: If ``tensor`` is neither a tensor nor a NumPy array.
            ValueError: If ``tensor`` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.as_tensor(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        ndim = tensor.ndim
        if ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    @staticmethod
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        """Return ``affine`` as a float64 ``(4, 4)`` array (identity if None).

        Raises:
            TypeError: If ``affine`` is not a tensor or NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine.astype(np.float64)

    def load(self) -> None:
        r"""Load the image from disk and cache its data and affine.

        Does nothing if the image is already loaded. If :attr:`path` is a
        list, every file is read with :attr:`reader` and the tensors are
        concatenated along the channel dimension; the affine matrix of the
        first file is used. A :class:`RuntimeWarning` is issued if the affine
        matrices of the files differ.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one path with :attr:`reader` and validate the result.

        Args:
            path: Path passed unchanged to :attr:`reader`.

        Returns:
            The tensor (made 4D by ``ensure_4d`` and validated by
            :meth:`_parse_tensor`) and the affine (validated by
            :meth:`_parse_affine`).
        """
        tensor, affine = self.reader(path)
        tensor = self._parse_tensor_shape(tensor)
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: Optional[bool] = None) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.IntTensor; memory: 93.8 KiB)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 588.0 KiB)
        """
        tensor, affine = sitk_to_nib(sitk_image)
        return cls(tensor=tensor, affine=affine)

    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.
        .. note:: To use this method, `Pillow` needs to be installed:
            `pip install Pillow`.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = (
                'Please install Pillow to use Image.as_pil():'
                ' pip install Pillow'
            )
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            # Grayscale: replicate the channel to build an RGB image
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        if transpose:
            tensor = tensor.permute(3, 2, 1, 0)
        else:
            tensor = tensor.permute(3, 1, 2, 0)
        # The leading dimension is the singleton depth of the 2D image
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))

    def to_gif(
            self,
            axis: int,
            duration: float,  # of full gif
            output_path: TypePath,
            loop: int = 0,
            rescale: bool = True,
            optimize: bool = True,
            reverse: bool = False,
        ) -> None:
        """Save an animated GIF of the image.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """  # noqa: E501
        from ..visualization import make_gif  # avoid circular import
        make_gif(
            self.data,
            axis,
            duration,
            output_path,
            loop=loop,
            rescale=rescale,
            optimize=optimize,
            reverse=reverse,
        )

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
649
650
651
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The type is always :attr:`torchio.INTENSITY`; passing any other
    ``type`` keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)
678
679
680
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    The type is always :attr:`torchio.LABEL`; passing any other ``type``
    keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Multiple files must be given as a sequence (e.g. a list), as the second
    positional argument of :class:`~torchio.Image` is ``type``.

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
707