Passed
Pull Request — main (#1404)
by Fernando
01:49
created

torchio.data.image.Image._repr_html_()   A

Complexity

Conditions 2

Size

Total Lines 21
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 18
nop 1
dl 0
loc 21
rs 9.5
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import base64
4
import io
5
import warnings
6
from collections import Counter
7
from collections.abc import Callable
8
from collections.abc import Sequence
9
from pathlib import Path
10
from typing import TYPE_CHECKING
11
from typing import Any
12
13
import humanize
14
import nibabel as nib
15
import numpy as np
16
import SimpleITK as sitk
17
import torch
18
from deprecated import deprecated
19
from nibabel.affines import apply_affine
20
21
from ..constants import AFFINE
22
from ..constants import DATA
23
from ..constants import INTENSITY
24
from ..constants import LABEL
25
from ..constants import PATH
26
from ..constants import STEM
27
from ..constants import TENSOR
28
from ..constants import TYPE
29
from ..types import TypeData
30
from ..types import TypeDataAffine
31
from ..types import TypeDirection3D
32
from ..types import TypePath
33
from ..types import TypeQuartetInt
34
from ..types import TypeSlice
35
from ..types import TypeTripletFloat
36
from ..types import TypeTripletInt
37
from ..utils import get_stem
38
from ..utils import guess_external_viewer
39
from ..utils import is_iterable
40
from ..utils import to_tuple
41
from .io import check_uint_to_int
42
from .io import ensure_4d
43
from .io import get_rotation_and_spacing_from_affine
44
from .io import get_sitk_metadata_from_ras_affine
45
from .io import nib_to_sitk
46
from .io import read_affine
47
from .io import read_image
48
from .io import read_shape
49
from .io import sitk_to_nib
50
from .io import write_image
51
52
if TYPE_CHECKING:
53
    from matplotlib.figure import Figure
54
55
# Keys managed by Image itself; they may not be passed as extra kwargs
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
# A (min, max) pair and the three per-axis pairs returned by get_bounds()
TypeBound = tuple[float, float]
TypeBounds = tuple[TypeBound, TypeBound, TypeBound]

# Message emitted when the deprecated ``data`` property setter is used
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
63
64
65
class Image(dict):
66
    r"""TorchIO image.
67
68
    For information about medical image orientation, check out `NiBabel docs`_,
69
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
70
    `SimpleITK docs`_.
71
72
    Args:
73
        path: Path to a file or sequence of paths to files that can be read by
74
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
75
            DICOM files. If :attr:`tensor` is given, the data in
76
            :attr:`path` will not be read.
77
            If a sequence of paths is given, data
78
            will be concatenated on the channel dimension so spatial
79
            dimensions must match.
80
        type: Type of image, such as :attr:`torchio.INTENSITY` or
81
            :attr:`torchio.LABEL`. This will be used by the transforms to
82
            decide whether to apply an operation, or which interpolation to use
83
            when resampling. For example, `preprocessing`_ and `augmentation`_
84
            intensity transforms will only be applied to images with type
85
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
86
            all types, and nearest neighbor interpolation is always used to
87
            resample images with type :attr:`torchio.LABEL`.
88
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
89
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
90
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
91
            :class:`torch.Tensor` or NumPy array with dimensions
92
            :math:`(C, W, H, D)`.
93
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
94
            coordinates. If ``None``, an identity matrix will be used. See the
95
            `NiBabel docs on coordinates`_ for more information.
96
        check_nans: If ``True``, issues a warning if NaNs are found
97
            in the image. If ``False``, images will not be checked for the
98
            presence of NaNs.
99
        reader: Callable object that takes a path and returns a 4D tensor and a
100
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
101
            is saved in a custom format, such as ``.npy`` (see example below).
102
            If the affine matrix is ``None``, an identity matrix will be used.
103
        **kwargs: Items that will be added to the image dictionary, e.g.
104
            acquisition parameters or image ID.
105
        verify_path: If ``True``, the path will be checked to see if it exists. If
106
            ``False``, the path will not be verified. This is useful when it is
107
            expensive to check the path, e.g., when reading a large dataset from a
108
            mounted drive.
109
110
    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
111
    when needed.
112
113
    Example:
114
        >>> import torchio as tio
115
        >>> import numpy as np
116
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
117
        >>> image  # not loaded yet
118
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; path: "t1.nii.gz")
119
        >>> times_two = 2 * image.data  # data is loaded and cached here
120
        >>> image
121
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; dtype: torch.FloatTensor; memory: 44.0 MiB)
122
        >>> image.save('doubled_image.nii.gz')
123
        >>> def numpy_reader(path):
124
        ...     data = np.load(path).astype(np.float32)
125
        ...     affine = np.eye(4)
126
        ...     return data, affine
127
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)
128
129
    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
130
    .. _preprocessing: https://docs.torchio.org/transforms/preprocessing.html#intensity
131
    .. _augmentation: https://docs.torchio.org/transforms/augmentation.html#intensity
132
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
133
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
134
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
135
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
136
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
137
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
138
    """
139
140
    def __init__(
        self,
        path: TypePath | Sequence[TypePath] | None = None,
        type: str | None = None,  # noqa: A002
        tensor: TypeData | None = None,
        affine: TypeData | None = None,
        check_nans: bool = False,  # removed by ITK by default
        reader: Callable[[TypePath], TypeDataAffine] = read_image,
        verify_path: bool = True,
        # Extra items stored in the image dictionary; values may be of any type
        **kwargs: Any,
    ):
        # Must be set before _parse_tensor(), which reads both attributes
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                FutureWarning,
                stacklevel=2,
            )
            type = INTENSITY  # noqa: A001

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        # An in-memory tensor makes the image loaded from the start
        if tensor is not None:
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        if 'channels_last' in kwargs:
            message = (
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/TorchIO-project/torchio/pull/685 and will be'
                ' removed in the future'
            )
            warnings.warn(message, FutureWarning, stacklevel=2)

        super().__init__(**kwargs)
        self.path = self._parse_path(path, verify=verify_path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
192
193
    def __repr__(self):
        """Summarize geometry plus either memory usage or the source path."""
        # Geometry is reported whether or not the voxel data are in memory
        properties = [
            f'shape: {self.shape}',
            f'spacing: {self.get_spacing_string()}',
            f'orientation: {self.orientation_str}+',
        ]
        if not self._loaded:
            properties.append(f'path: "{self.path}"')
        else:
            properties += [
                f'dtype: {self.data.type()}',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ]
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'
212
213
    def _repr_html_(self):
        """Return an HTML ``<img>`` tag embedding a PNG plot of the image.

        This hook is used by IPython/Jupyter for rich display. If Matplotlib
        is not installed, or no Matplotlib figure can be produced (for example
        for 2D images, which :meth:`plot` shows with Pillow instead), the plain
        text representation is returned.
        """
        try:
            from matplotlib import pyplot as plt
            from matplotlib.figure import Figure
        except ImportError:
            return self.__repr__()

        # For 2D images plot() opens an external viewer and returns no figure,
        # so there is nothing to embed; avoid the side effect entirely
        if self.is_2d():
            return self.__repr__()

        buffer = io.BytesIO()
        fig = self.plot(
            return_fig=True,
            output_path=buffer,
            show=False,
            savefig_kwargs={'bbox_inches': 'tight'},
        )
        if not isinstance(fig, Figure):
            return self.__repr__()
        # Close the figure so it is not rendered a second time by the backend
        plt.close(fig)

        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
        html = f'<img src="data:image/png;base64,{img_str}"/>'
        return html
234
235
    def __getitem__(self, item):
        """Crop with spatial indices, or look up a key, loading lazily."""
        # Integer, slice and tuple indices crop the image
        if isinstance(item, (slice, int, tuple)):
            return self._crop_from_slices(item)

        # Data and affine are read from disk the first time they are missing
        needs_loading = item in (DATA, AFFINE) and item not in self
        if needs_loading:
            self.load()
        return super().__getitem__(item)
243
244
    def __array__(self, dtype=None, copy=None):
        """Return the image data as a NumPy array (NumPy array protocol).

        Args:
            dtype: Optional data type requested by NumPy.
            copy: ``True`` to always return a copy, ``False`` to forbid
                copying, or ``None`` to copy only when needed. NumPy 2 passes
                this argument to ``__array__``.

        Raises:
            ValueError: If ``copy`` is ``False`` but a dtype conversion
                requires a copy.
        """
        array = self.data.numpy()
        if dtype is not None and array.dtype != np.dtype(dtype):
            if copy is False:
                message = f'Unable to avoid a copy while converting to {dtype}'
                raise ValueError(message)
            return array.astype(dtype)
        if copy:
            return array.copy()
        return array
246
247
    def __copy__(self):
        """Return a shallow copy of the image.

        If the image is loaded, its current tensor and affine are passed to the
        new instance instead of being re-read. Extra (non-protected) dictionary
        items are passed by reference, without copying.
        """
        kwargs = {TYPE: self.type, PATH: self.path}
        if self._loaded:
            kwargs.update({TENSOR: self.data, AFFINE: self.affine})
        extra_items = {
            key: value for key, value in self.items() if key not in PROTECTED_KEYS
        }
        kwargs.update(extra_items)  # should I copy? deepcopy?
        # NOTE(review): verify_path is not stored on the instance, so the new
        # image verifies its path again with the default setting
        return type(self)(check_nans=self.check_nans, reader=self.reader, **kwargs)
266
267
    @property
    def data(self) -> torch.Tensor:
        """4D tensor with the image data (alias: :attr:`Image.tensor`).

        Accessing it loads the image from disk if it is not in memory yet.
        """
        tensor = self[DATA]
        return tensor

    @data.setter  # type: ignore[misc]
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        # Deprecated: kept only for backward compatibility, use set_data()
        self.set_data(tensor)
276
277
    def set_data(self, tensor: TypeData):
        """Validate ``tensor`` and store it as the image data.

        The image is marked as loaded afterwards.

        Args:
            tensor: 4D tensor or array with dimensions :math:`(C, W, H, D)`.
        """
        parsed = self._parse_tensor(tensor, none_ok=False)
        self[DATA] = parsed
        self._loaded = True
285
286
    @property
    def tensor(self) -> torch.Tensor:
        """Alias of :attr:`Image.data`."""
        data = self.data
        return data
290
291
    @property
    def affine(self) -> np.ndarray:
        """4x4 matrix mapping voxel indices to world coordinates."""
        # Usually the affine is read from the file header without loading the
        # voxels. DICOM directories, multi-file images and custom readers are
        # the exception (SimpleITK probably cannot read their metadata), so
        # everything is loaded into memory through __getitem__ instead.
        must_load = (
            self._loaded
            or self._is_dir()
            or self._is_multipath()
            or self.reader is not read_image
        )
        if must_load:
            return self[AFFINE]
        assert self.path is not None
        assert isinstance(self.path, (str, Path))
        return read_affine(self.path)

    @affine.setter
    def affine(self, matrix):
        parsed = self._parse_affine(matrix)
        self[AFFINE] = parsed
310
311
    @property
    def type(self) -> str:  # noqa: A003
        """Image type, such as ``torchio.INTENSITY`` or ``torchio.LABEL``."""
        image_type = self[TYPE]
        return image_type
314
315
    @property
    def shape(self) -> TypeQuartetInt:
        """Tensor shape as :math:`(C, W, H, D)`.

        If the data are not in memory, the shape is read from the file with
        ``read_shape`` when possible, which avoids loading the whole image.
        """
        custom_reader = self.reader is not read_image
        # Directories (probably DICOM), multiple paths and custom readers are
        # not read through read_shape, so the data must be loaded. _is_dir()
        # also copes with paths that are None or sequences, unlike a direct
        # self.path.is_dir() call, which left is_dir undefined in those cases.
        must_load = (
            self._loaded
            or custom_reader
            or self._is_multipath()
            or self._is_dir()
        )
        shape: TypeQuartetInt
        if must_load:
            channels, si, sj, sk = self.data.shape
            shape = channels, si, sj, sk
        else:
            assert isinstance(self.path, (str, Path))
            shape = read_shape(self.path)
        return shape
330
331
    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Spatial part of the tensor shape, :math:`(W, H, D)`."""
        full_shape = self.shape
        return full_shape[1:]
335
336
    def check_is_2d(self) -> None:
        """Raise ``RuntimeError`` unless :meth:`is_2d` is ``True``."""
        if self.is_2d():
            return
        message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
        raise RuntimeError(message)
340
341
    @property
    def height(self) -> int:
        """Size along the second spatial axis; only valid for 2D images."""
        self.check_is_2d()
        spatial = self.spatial_shape
        return spatial[1]
346
347
    @property
    def width(self) -> int:
        """Size along the first spatial axis; only valid for 2D images."""
        self.check_is_2d()
        spatial = self.spatial_shape
        return spatial[0]
352
353
    @property
    def orientation(self) -> tuple[str, str, str]:
        """Axis codes derived from the affine by NiBabel's ``aff2axcodes``."""
        codes = nib.orientations.aff2axcodes(self.affine)
        return codes
357
358
    @property
    def orientation_str(self) -> str:
        """Orientation codes concatenated into one string, e.g. ``'RAS'``."""
        codes = self.orientation
        return ''.join(codes)
362
363
    @property
    def direction(self) -> TypeDirection3D:
        """Direction derived from :attr:`affine` (computed with ``lps=False``)."""
        metadata = get_sitk_metadata_from_ras_affine(self.affine, lps=False)
        _, _, direction = metadata
        return direction  # type: ignore[return-value]
370
371
    @property
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm, as built-in floats."""
        _, voxel_sizes = get_rotation_and_spacing_from_affine(self.affine)
        first, second, third = (float(size) for size in voxel_sizes)
        return first, second, third
377
378
    @property
    def origin(self) -> tuple[float, float, float]:
        """Center of first voxel in array, in mm.

        Returns:
            The translation column of :attr:`affine` as built-in floats.
        """
        ox, oy, oz = self.affine[:3, 3]
        # Cast NumPy scalars to built-in floats, matching the annotation and
        # the behavior of :attr:`spacing`
        return float(ox), float(oy), float(oz)
383
384
    @property
    def itemsize(self):
        """Size in bytes of one element of the data tensor."""
        tensor = self.data
        return tensor.element_size()
388
389
    @property
    def memory(self) -> float:
        """Bytes taken by the data tensor: number of elements times itemsize."""
        num_elements = np.prod(self.shape)
        return num_elements * self.itemsize
393
394
    @property
    def bounds(self) -> np.ndarray:
        """World positions of the centers of the first and last voxels."""
        first_index = (0, 0, 0)
        last_index = np.array(self.spatial_shape) - 1
        corners = [
            apply_affine(self.affine, index) for index in (first_index, last_index)
        ]
        return np.array(corners)
402
403
    @property
    def num_channels(self) -> int:
        """Size of the first (channel) dimension of the 4D data tensor."""
        tensor = self.data
        return len(tensor)
407
408
    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to a (negative) axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        letter = axis[0].upper()

        # Top/Bottom name the vertical 2D axis, which is H in (C, W, H, D);
        # "height" vs "horizontal" would be ambiguous
        if letter in 'TB':
            return -2

        orientation = self.orientation
        try:
            position = orientation.index(letter)
        except ValueError:
            position = orientation.index(self.flip_axis(letter))
        # Negative indices work whether or not the channel dimension is counted
        return position - 3
444
445
    @staticmethod
446
    def flip_axis(axis: str) -> str:
447
        """Return the opposite axis label. For example, ``'L'`` -> ``'R'``.
448
449
        Args:
450
            axis: Axis label, such as ``'L'`` or ``'left'``.
451
        """
452
        labels = 'LRPAISTBDV'
453
        first = labels[::2]
454
        last = labels[1::2]
455
        flip_dict = dict(zip(first + last, last + first, strict=True))
456
        axis = axis[0].upper()
457
        flipped_axis = flip_dict.get(axis)
458
        if flipped_axis is None:
459
            values = ', '.join(labels)
460
            message = f'Axis not understood. Please use one of: {values}'
461
            raise ValueError(message)
462
        return flipped_axis
463
464
    def get_spacing_string(self) -> str:
        """Format :attr:`spacing` with two decimals, e.g. ``'(1.00, 1.00, 1.00)'``."""
        formatted = ', '.join(f'{value:.2f}' for value in self.spacing)
        return f'({formatted})'
468
469
    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image.

        Voxel edges are used (indices offset by half a voxel), not centers.
        """
        lower_corner = (-0.5, -0.5, -0.5)
        upper_corner = np.array(self.spatial_shape) - 0.5
        corners = np.array(
            [
                apply_affine(self.affine, lower_corner),
                apply_affine(self.affine, upper_corner),
            ]
        )
        # Transpose so that each row holds the two values along one axis
        bounds_x, bounds_y, bounds_z = corners.T.tolist()  # type: ignore[misc]
        return bounds_x, bounds_y, bounds_z  # type: ignore[return-value]
478
479
    def _parse_single_path(
        self,
        path: TypePath,
        *,
        verify: bool = True,
    ) -> Path:
        """Convert ``path`` to an expanded :class:`~pathlib.Path`.

        Raises:
            TypeError: If a tensor/array or a non path-like object is given.
            RuntimeError: If the conversion to a path fails.
            FileNotFoundError: If ``verify`` is ``True`` and the path is
                neither a file nor a directory.
        """
        if isinstance(path, (torch.Tensor, np.ndarray)):
            name = type(self).__name__
            raise TypeError(
                'Expected type str or Path but found a tensor/array. Instead of'
                f' {name}(your_tensor),'
                f' use {name}(tensor=your_tensor).'
            )
        try:
            parsed = Path(path).expanduser()
        except TypeError as err:
            raise TypeError(
                f'Expected type str or Path but found an object with type'
                f' {type(path)} instead'
            ) from err
        except RuntimeError as err:
            raise RuntimeError(
                f'Conversion to path not possible for variable: {path}'
            ) from err
        # The filesystem is only queried when verification is requested;
        # directories are accepted because they may contain DICOM files
        if verify and not (parsed.is_file() or parsed.is_dir()):
            raise FileNotFoundError(f'File not found: "{parsed}"')
        return parsed
510
511
    def _parse_path(
        self,
        path: TypePath | Sequence[TypePath] | None,
        *,
        verify: bool = True,
    ) -> Path | list[Path] | None:
        """Parse ``None``, a single path or a sequence of paths."""
        if path is None:
            return None
        if isinstance(path, dict):
            # https://github.com/TorchIO-project/torchio/pull/838
            raise TypeError('The path argument cannot be a dictionary')
        if not self._is_paths_sequence(path):
            return self._parse_single_path(path, verify=verify)  # type: ignore[arg-type]
        return [self._parse_single_path(p, verify=verify) for p in path]  # type: ignore[union-attr]
526
527
    def _parse_tensor(
528
        self,
529
        tensor: TypeData | None,
530
        none_ok: bool = True,
531
    ) -> torch.Tensor | None:
532
        if tensor is None:
533
            if none_ok:
534
                return None
535
            else:
536
                raise RuntimeError('Input tensor cannot be None')
537
        if isinstance(tensor, np.ndarray):
538
            tensor = check_uint_to_int(tensor)
539
            tensor = torch.as_tensor(tensor)
540
        elif not isinstance(tensor, torch.Tensor):
541
            message = (
542
                'Input tensor must be a PyTorch tensor or NumPy array,'
543
                f' but type "{type(tensor)}" was found'
544
            )
545
            raise TypeError(message)
546
        ndim = tensor.ndim
547
        if ndim != 4:
548
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
549
        if tensor.dtype == torch.bool:
550
            tensor = tensor.to(torch.uint8)
551
        if self.check_nans and torch.isnan(tensor).any():
552
            warnings.warn('NaNs found in tensor', RuntimeWarning, stacklevel=2)
553
        return tensor
554
555
    @staticmethod
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        """Delegate to ``ensure_4d`` so that the reader output becomes 4D."""
        reshaped = ensure_4d(tensor)
        return reshaped
558
559
    @staticmethod
560
    def _parse_affine(affine: TypeData | None) -> np.ndarray:
561
        if affine is None:
562
            return np.eye(4)
563
        if isinstance(affine, torch.Tensor):
564
            affine = affine.numpy()
565
        if not isinstance(affine, np.ndarray):
566
            bad_type = type(affine)
567
            raise TypeError(f'Affine must be a NumPy array, not {bad_type}')
568
        if affine.shape != (4, 4):
569
            bad_shape = affine.shape
570
            raise ValueError(f'Affine shape must be (4, 4), not {bad_shape}')
571
        return affine.astype(np.float64)
572
573
    @staticmethod
    def _is_paths_sequence(path: TypePath | Sequence[TypePath] | None) -> bool:
        """Return True for iterables that are not strings (several paths)."""
        return not isinstance(path, str) and is_iterable(path)
577
578
    def _is_multipath(self) -> bool:
        """Whether :attr:`path` holds a sequence of paths."""
        path = self.path
        return self._is_paths_sequence(path)
580
581
    def _is_dir(self) -> bool:
        """Whether :attr:`path` is a single existing directory."""
        if self._is_multipath() or self.path is None:
            return False
        assert isinstance(self.path, Path)
        return self.path.is_dir()
590
591
    def load(self) -> None:
        r"""Load the image from disk.

        The file(s) in :attr:`path` are read with :meth:`read_and_check`. The
        result is stored as a 4D tensor of size :math:`(C, W, H, D)` and a
        :math:`4 \times 4` affine matrix that converts voxel indices to world
        coordinates. If several paths are given, their tensors are
        concatenated along the channel dimension and the affine of the first
        file is kept. Nothing happens if the image is already loaded.

        Raises:
            RuntimeError: If the spatial shapes of multiple files differ.
        """
        if self._loaded:
            return

        paths: list[Path]
        if self._is_multipath():
            paths = self.path  # type: ignore[assignment]
        else:
            paths = [self.path]  # type: ignore[list-item]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            # Different affines are tolerated (the first one is used) but reported
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning, stacklevel=2)
            # Concatenation along channels requires equal spatial shapes
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True
631
632
    def unload(self) -> None:
        """Unload the image from memory.

        The tensor and affine are removed from the image dictionary, so the
        next access to :attr:`data` (or to :attr:`affine`, when it cannot be
        read from the header) loads the image again from :attr:`path`.

        Raises:
            RuntimeError: If the image has not been loaded yet or if no path
                is available.
        """
        if not self._loaded:
            message = 'Image cannot be unloaded as it has not been loaded yet'
            raise RuntimeError(message)
        if self.path is None:
            message = (
                'Cannot unload image as no path is available'
                ' from where the image could be loaded again'
            )
            raise RuntimeError(message)
        # Remove the keys rather than storing None: __getitem__ only reloads
        # when the key is missing, so None values would be returned as data
        self.pop(DATA, None)
        self.pop(AFFINE, None)
        self._loaded = False
651
652
    def read_and_check(self, path: TypePath) -> TypeDataAffine:
        r"""Read a single file with :attr:`reader` and validate the result.

        Returns:
            Tuple with the parsed 4D tensor and the :math:`4 \times 4` affine
            as float64.
        """
        tensor, affine = self.reader(path)
        uses_custom_reader = self.reader is not read_image
        if uses_custom_reader and isinstance(tensor, np.ndarray):
            # Make sure the data type is compatible with PyTorch
            tensor = check_uint_to_int(tensor)
        tensor = self._parse_tensor_shape(tensor)  # type: ignore[assignment]
        tensor = self._parse_tensor(tensor)  # type: ignore[assignment]
        affine = self._parse_affine(affine)
        # NOTE(review): _parse_tensor already warns about NaNs when check_nans
        # is set, so such a file triggers two warnings; this one names the file
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(
                f'NaNs found in file "{path}"',
                RuntimeWarning,
                stacklevel=2,
            )
        return tensor, affine
667
668
    def save(self, path: TypePath, squeeze: bool | None = None) -> None:
        """Write the image data and affine to ``path`` with ``write_image``.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        data = self.data
        affine = self.affine
        write_image(data, affine, path, squeeze=squeeze)
683
684
    def is_2d(self) -> bool:
        """Whether the last spatial dimension (depth) has size 1."""
        depth = self.shape[-1]
        return depth == 1
686
687
    def numpy(self) -> np.ndarray:
        """Return the image data as a NumPy array (via ``__array__``)."""
        array = np.asarray(self)
        return array
690
691
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Convert the image to :class:`sitk.Image` using ``nib_to_sitk``.

        Extra keyword arguments are forwarded to ``nib_to_sitk``.
        """
        data, affine = self.data, self.affine
        return nib_to_sitk(data, affine, **kwargs)
694
695
    @classmethod
    def from_sitk(cls, sitk_image):
        """Create an image of this class from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 93.8 KiB; dtype: torch.IntTensor)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 588.0 KiB; dtype: torch.FloatTensor)
        """
        converted_tensor, converted_affine = sitk_to_nib(sitk_image)
        return cls(tensor=converted_tensor, affine=converted_affine)
711
712
    def as_pil(self, transpose=True):
        """Convert a 2D image to a :class:`PIL.Image` in RGB(A).

        Single-channel images are replicated into three channels.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        .. note:: To use this method, Pillow needs to be installed:
            ``pip install Pillow``.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = 'Please install Pillow to use Image.as_pil(): pip install Pillow'
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        num_channels = len(tensor)
        if num_channels not in (1, 3, 4):
            raise NotImplementedError(
                'Only 1, 3 or 4 channels are supported for conversion to Pillow image'
            )
        if num_channels == 1:
            tensor = tensor.repeat(3, 1, 1, 1)
        # Move channels last; transpose swaps the two spatial axes
        order = (3, 2, 1, 0) if transpose else (3, 1, 2, 0)
        array = tensor.permute(*order).clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))
740
741
    def to_gif(
        self,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        rescale: bool = True,
        optimize: bool = True,
        reverse: bool = False,
    ) -> None:
        """Save an animated GIF sweeping through one spatial axis.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """
        # Imported here to avoid a circular import
        from ..visualization import make_gif

        options = {
            'loop': loop,
            'rescale': rescale,
            'optimize': optimize,
            'reverse': reverse,
        }
        make_gif(self.data, axis, duration, output_path, **options)
778
779
    def to_ras(self) -> Image:
780
        if self.orientation_str != 'RAS':
781
            from ..transforms.preprocessing.spatial.to_canonical import ToCanonical
782
783
            return ToCanonical()(self)
784
        return self
785
786
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
787
        """Get image center in RAS+ or LPS+ coordinates.
788
789
        Args:
790
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
791
                the first dimension grows towards the left, etc. Otherwise, the
792
                coordinates will be in RAS+ orientation.
793
        """
794
        size = np.array(self.spatial_shape)
795
        center_index = (size - 1) / 2
796
        r, a, s = apply_affine(self.affine, center_index)
797
        if lps:
798
            return (-r, -a, s)
799
        else:
800
            return (r, a, s)
801
802
    def set_check_nans(self, check_nans: bool) -> None:
803
        self.check_nans = check_nans
804
805
    def plot(self, return_fig: bool = False, **kwargs) -> None | Figure:
806
        """Plot image."""
807
        if self.is_2d():
808
            self.as_pil().show()
809
        else:
810
            from ..visualization import plot_volume  # avoid circular import
811
812
            figure = plot_volume(self, **kwargs)
813
            if return_fig:
814
                assert figure is not None
815
                return figure
816
817
    def show(self, viewer_path: TypePath | None = None) -> None:
818
        """Open the image using external software.
819
820
        Args:
821
            viewer_path: Path to the application used to view the image. If
822
                ``None``, the value of the environment variable
823
                ``SITK_SHOW_COMMAND`` will be used. If this variable is also
824
                not set, TorchIO will try to guess the location of
825
                `ITK-SNAP <http://www.itksnap.org/pmwiki/pmwiki.php>`_ and
826
                `3D Slicer <https://www.slicer.org/>`_.
827
828
        Raises:
829
            RuntimeError: If the viewer is not found.
830
        """
831
        sitk_image = self.as_sitk()
832
        image_viewer = sitk.ImageViewer()
833
        # This is so that 3D Slicer creates segmentation nodes from label maps
834
        if self.__class__.__name__ == 'LabelMap':
835
            image_viewer.SetFileExtension('.seg.nrrd')
836
        if viewer_path is not None:
837
            image_viewer.SetApplication(str(viewer_path))
838
        try:
839
            image_viewer.Execute(sitk_image)
840
        except RuntimeError as e:
841
            viewer_path = guess_external_viewer()
842
            if viewer_path is None:
843
                message = (
844
                    'No external viewer has been found. Please set the'
845
                    ' environment variable SITK_SHOW_COMMAND to a viewer of'
846
                    ' your choice'
847
                )
848
                raise RuntimeError(message) from e
849
            image_viewer.SetApplication(str(viewer_path))
850
            image_viewer.Execute(sitk_image)
851
852
    def _crop_from_slices(
853
        self,
854
        slices: TypeSlice | tuple[TypeSlice, ...],
855
    ) -> Image:
856
        from ..transforms import Crop
857
858
        slices_tuple = to_tuple(slices)  # type: ignore[assignment]
859
        cropping: list[int] = []
860
        for dim, slice_ in enumerate(slices_tuple):
861
            if isinstance(slice_, slice):
862
                pass
863
            elif slice_ is Ellipsis:
864
                message = 'Ellipsis slicing is not supported yet'
865
                raise NotImplementedError(message)
866
            elif isinstance(slice_, int):
867
                slice_ = slice(slice_, slice_ + 1)  # type: ignore[assignment]
868
            else:
869
                message = f'Slice type not understood: "{type(slice_)}"'
870
                raise TypeError(message)
871
            shape_dim = self.spatial_shape[dim]
872
            assert isinstance(slice_, slice)
873
            start, stop, step = slice_.indices(shape_dim)
874
            if step != 1:
875
                message = (
876
                    'Slicing with steps different from 1 is not supported yet.'
877
                    ' Use the Crop transform instead'
878
                )
879
                raise ValueError(message)
880
            crop_ini = start
881
            crop_fin = shape_dim - stop
882
            cropping.extend([crop_ini, crop_fin])
883
        while dim < 2:
0 ignored issues
show
introduced by
The variable dim does not seem to be defined in case the for loop on line 860 is not entered. Are you sure this can never be the case?
Loading history...
884
            cropping.extend([0, 0])
885
            dim += 1
886
        w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
887
        cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin  # making mypy happy
888
        return Crop(cropping_arg)(self)  # type: ignore[return-value]
889
890
891
class ScalarImage(Image):
892
    """Image whose pixel values represent scalars.
893
894
    Example:
895
        >>> import torch
896
        >>> import torchio as tio
897
        >>> # Loading from a file
898
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
899
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
900
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
901
        >>> data, affine = image.data, image.affine
902
        >>> affine.shape
903
        (4, 4)
904
        >>> image.data is image[tio.DATA]
905
        True
906
        >>> image.data is image.tensor
907
        True
908
        >>> type(image.data)
909
        torch.Tensor
910
911
    See :class:`~torchio.Image` for more information.
912
    """
913
914
    def __init__(self, *args, **kwargs):
915
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
916
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
917
        kwargs.update({'type': INTENSITY})
918
        super().__init__(*args, **kwargs)
919
920
    def hist(self, **kwargs) -> None:
921
        """Plot histogram."""
922
        from ..visualization import plot_histogram
923
924
        x = self.data.flatten().numpy()
925
        plot_histogram(x, **kwargs)
926
927
    def to_video(
928
        self,
929
        output_path: TypePath,
930
        frame_rate: float | None = 15,
931
        seconds: float | None = None,
932
        direction: str = 'I',
933
        verbosity: str = 'error',
934
    ) -> None:
935
        """Create a video showing all image slices along a specified direction.
936
937
        Args:
938
            output_path: Path to the output video file.
939
            frame_rate: Number of frames per second (FPS).
940
            seconds: Target duration of the full video.
941
            direction:
942
            verbosity:
943
944
        .. note:: Only ``frame_rate`` or ``seconds`` may (and must) be specified.
945
        """
946
        from ..visualization import make_video  # avoid circular import
947
948
        make_video(
949
            self.to_ras(),  # type: ignore[arg-type]
950
            output_path,
951
            frame_rate=frame_rate,
952
            seconds=seconds,
953
            direction=direction,
954
            verbosity=verbosity,
955
        )
956
957
958
class LabelMap(Image):
959
    """Image whose pixel values represent segmentation labels.
960
961
    A sequence of paths to 3D images can be passed to create a 4D image.
962
    This is useful to create a
963
    `tissue probability map (TPM) <https://andysbrainbook.readthedocs.io/en/latest/SPM/SPM_Short_Course/SPM_04_Preprocessing/04_SPM_Segmentation.html#tissue-probability-maps>`,
964
    which contains the probability of each voxel belonging to a certain tissue type,
965
    or to create a label map with overlapping labels.
966
967
    Intensity transforms are not applied to these images.
968
969
    Nearest neighbor interpolation is always used to resample label maps,
970
    independently of the specified interpolation type in the transform
971
    instantiation.
972
973
    Example:
974
        >>> import torch
975
        >>> import torchio as tio
976
        >>> binary_tensor = torch.rand(1, 128, 128, 68) > 0.5
977
        >>> label_map = tio.LabelMap(tensor=binary_tensor)  # from a tensor
978
        >>> label_map = tio.LabelMap('t1_seg.nii.gz')  # from a file
979
        >>> # Create a 4D tissue probability map from different 3D images
980
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
981
        >>> tpm = tio.LabelMap(tissues)
982
983
    See :class:`~torchio.Image` for more information.
984
    """
985
986
    def __init__(self, *args, **kwargs):
987
        if 'type' in kwargs and kwargs['type'] != LABEL:
988
            raise ValueError('Type of LabelMap is always torchio.LABEL')
989
        kwargs.update({'type': LABEL})
990
        super().__init__(*args, **kwargs)
991
992
    def count_nonzero(self) -> int:
993
        """Get the number of voxels that are not 0."""
994
        return int(self.data.count_nonzero().item())
995
996
    def count_labels(self) -> dict[int, int]:
997
        """Get the number of voxels in each label."""
998
        values_list = self.data.flatten().tolist()
999
        counter = Counter(values_list)
1000
        counts = {label: counter[label] for label in sorted(counter)}
1001
        return counts
1002