Passed
Push — main ( 287682...fc78a5 )
by Fernando
01:27
created

torchio.data.image.ScalarImage.__init__()   A

Complexity

Conditions 3

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 5
nop 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import warnings
4
from collections import Counter
5
from collections.abc import Sequence
6
from pathlib import Path
7
from typing import Any
8
from typing import Callable
9
10
import humanize
11
import nibabel as nib
12
import numpy as np
13
import SimpleITK as sitk
14
import torch
15
from deprecated import deprecated
16
from nibabel.affines import apply_affine
17
18
from ..constants import AFFINE
19
from ..constants import DATA
20
from ..constants import INTENSITY
21
from ..constants import LABEL
22
from ..constants import PATH
23
from ..constants import STEM
24
from ..constants import TENSOR
25
from ..constants import TYPE
26
from ..types import TypeData
27
from ..types import TypeDataAffine
28
from ..types import TypeDirection3D
29
from ..types import TypePath
30
from ..types import TypeQuartetInt
31
from ..types import TypeSlice
32
from ..types import TypeTripletFloat
33
from ..types import TypeTripletInt
34
from ..utils import get_stem
35
from ..utils import guess_external_viewer
36
from ..utils import is_iterable
37
from ..utils import to_tuple
38
from .io import check_uint_to_int
39
from .io import ensure_4d
40
from .io import get_rotation_and_spacing_from_affine
41
from .io import get_sitk_metadata_from_ras_affine
42
from .io import nib_to_sitk
43
from .io import read_affine
44
from .io import read_image
45
from .io import read_shape
46
from .io import sitk_to_nib
47
from .io import write_image
48
49
# Keys managed by the image itself; Image.__init__ rejects them as extra kwargs
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# Pair of world coordinates along one axis, and the triplet for the three axes
# (return type of Image.get_bounds)
TypeBound = tuple[float, float]
TypeBounds = tuple[TypeBound, TypeBound, TypeBound]

# Reason shown by the deprecation decorator on the ``Image.data`` setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
57
58
59
class Image(dict):
60
    r"""TorchIO image.
61
62
    For information about medical image orientation, check out `NiBabel docs`_,
63
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
64
    `SimpleITK docs`_.
65
66
    Args:
67
        path: Path to a file or sequence of paths to files that can be read by
68
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
69
            DICOM files. If :attr:`tensor` is given, the data in
70
            :attr:`path` will not be read.
71
            If a sequence of paths is given, data
72
            will be concatenated on the channel dimension so spatial
73
            dimensions must match.
74
        type: Type of image, such as :attr:`torchio.INTENSITY` or
75
            :attr:`torchio.LABEL`. This will be used by the transforms to
76
            decide whether to apply an operation, or which interpolation to use
77
            when resampling. For example, `preprocessing`_ and `augmentation`_
78
            intensity transforms will only be applied to images with type
79
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
80
            all types, and nearest neighbor interpolation is always used to
81
            resample images with type :attr:`torchio.LABEL`.
82
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
83
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
84
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
85
            :class:`torch.Tensor` or NumPy array with dimensions
86
            :math:`(C, W, H, D)`.
87
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
88
            coordinates. If ``None``, an identity matrix will be used. See the
89
            `NiBabel docs on coordinates`_ for more information.
90
        check_nans: If ``True``, issues a warning if NaNs are found
91
            in the image. If ``False``, images will not be checked for the
92
            presence of NaNs.
93
        reader: Callable object that takes a path and returns a 4D tensor and a
94
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
95
            is saved in a custom format, such as ``.npy`` (see example below).
96
            If the affine matrix is ``None``, an identity matrix will be used.
97
        **kwargs: Items that will be added to the image dictionary, e.g.
98
            acquisition parameters or image ID.
99
        verify_path: If ``True``, the path will be checked to see if it exists. If
100
            ``False``, the path will not be verified. This is useful when it is
101
            expensive to check the path, e.g., when reading a large dataset from a
102
            mounted drive.
103
104
    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
105
    when needed.
106
107
    Example:
108
        >>> import torchio as tio
109
        >>> import numpy as np
110
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
111
        >>> image  # not loaded yet
112
        ScalarImage(path: t1.nii.gz; type: intensity)
113
        >>> times_two = 2 * image.data  # data is loaded and cached here
114
        >>> image
115
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
116
        >>> image.save('doubled_image.nii.gz')
117
        >>> def numpy_reader(path):
118
        ...     data = np.load(path).astype(np.float32)
119
        ...     affine = np.eye(4)
120
        ...     return data, affine
121
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)
122
123
    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
124
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
125
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
126
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
127
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
128
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
129
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
130
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
131
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
132
    """
133
134
    def __init__(
        self,
        path: TypePath | Sequence[TypePath] | None = None,
        type: str | None = None,  # noqa: A002
        tensor: TypeData | None = None,
        affine: TypeData | None = None,
        check_nans: bool = False,  # removed by ITK by default
        reader: Callable[[TypePath], TypeDataAffine] = read_image,
        verify_path: bool = True,
        **kwargs: Any,
    ):
        # Arguments are documented in the class docstring.
        # These two attributes must exist before any tensor is parsed, as
        # _parse_tensor() reads self.check_nans
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                FutureWarning,
                stacklevel=2,
            )
            type = INTENSITY  # noqa: A001

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Data given in memory: nothing will need to be read from disk
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True

        # Extra items must not clash with the keys handled by the class
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        if 'channels_last' in kwargs:
            message = (
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/TorchIO-project/torchio/pull/685 and will be'
                ' removed in the future'
            )
            warnings.warn(message, FutureWarning, stacklevel=2)

        # dict.__init__ with keyword arguments adds the extra items and keeps
        # the data and affine entries that may have been stored above
        super().__init__(**kwargs)
        self.path = self._parse_path(path, verify=verify_path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
186
187
    def __repr__(self):
        # Geometry fields are always shown, in this order
        shape_text = f'shape: {self.shape}'
        spacing_text = f'spacing: {self.get_spacing_string()}'
        orientation_text = f'orientation: {"".join(self.orientation)}+'
        fields = [shape_text, spacing_text, orientation_text]

        if not self._loaded:
            # No tensor in memory: show where the data would come from
            fields.append(f'path: "{self.path}"')
        else:
            fields.append(f'dtype: {self.data.type()}')
            size = humanize.naturalsize(self.memory, binary=True)
            fields.append(f'memory: {size}')

        joined = '; '.join(fields)
        return f'{self.__class__.__name__}({joined})'
206
207
    def __getitem__(self, item):
        # Slices, integers and tuples crop the image instead of reading a key
        if isinstance(item, (slice, int, tuple)):
            return self._crop_from_slices(item)

        # Data and affine are read lazily the first time they are requested
        must_load = item in (DATA, AFFINE) and item not in self
        if must_load:
            self.load()
        return super().__getitem__(item)
215
216
    def __array__(self):
        # Lets NumPy convert the image, e.g. with np.asarray(image)
        tensor = self.data
        return tensor.numpy()
218
219
    def __copy__(self):
        # Arguments needed to rebuild an image of the same class
        init_kwargs = {TYPE: self.type, PATH: self.path}
        if self._loaded:
            init_kwargs[TENSOR] = self.data
            init_kwargs[AFFINE] = self.affine
        # Forward the user-defined items; values are shared, not copied
        extras = {
            key: value for key, value in self.items() if key not in PROTECTED_KEYS
        }
        init_kwargs.update(extras)
        image_class = type(self)
        return image_class(
            check_nans=self.check_nans,
            reader=self.reader,
            **init_kwargs,
        )
238
239
    @property
    def data(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.tensor`)."""
        tensor = self[DATA]
        return tensor

    @data.setter  # type: ignore[misc]
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        # Deprecated path: set_data() is the supported way to assign the tensor
        self.set_data(tensor)
248
249
    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        The input is validated with :meth:`_parse_tensor`; ``None`` is not
        accepted.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        parsed = self._parse_tensor(tensor, none_ok=False)
        self[DATA] = parsed
256
257
    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.data`)."""
        data = self.data
        return data
261
262
    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        # The stored matrix is used (loading the image if needed) when:
        # - the image is already loaded,
        # - the path is a directory (probably DICOM),
        # - several paths were given (used to create a 4D image), or
        # - a custom reader is used, as SimpleITK probably won't be able to
        #   read the metadata
        uses_custom_reader = self.reader is not read_image
        use_stored_affine = (
            self._loaded
            or self._is_dir()
            or self._is_multipath()
            or uses_custom_reader
        )
        if use_stored_affine:
            return self[AFFINE]
        assert self.path is not None
        assert isinstance(self.path, (str, Path))
        return read_affine(self.path)

    @affine.setter
    def affine(self, matrix):
        # Validate the matrix before storing it
        parsed = self._parse_affine(matrix)
        self[AFFINE] = parsed
281
282
    @property
    def type(self) -> str:  # noqa: A003
        """Image type stored under the ``TYPE`` key."""
        image_type = self[TYPE]
        return image_type
285
286
    @property
    def shape(self) -> TypeQuartetInt:
        """Tensor shape as :math:`(C, W, H, D)`.

        The shape of the tensor is used when the image is already loaded, a
        custom reader is used, several paths were given or the path is a
        directory. Otherwise, the shape is obtained by calling
        ``read_shape`` on the path.
        """
        custom_reader = self.reader is not read_image
        multipath = self._is_multipath()
        shape: TypeQuartetInt
        # _is_dir() always returns a bool (False for a missing path or for
        # multiple paths), so no local flag can be left unbound when the path
        # is not a single Path instance
        if self._loaded or custom_reader or multipath or self._is_dir():
            channels, si, sj, sk = self.data.shape
            shape = channels, si, sj, sk
        else:
            assert isinstance(self.path, (str, Path))
            shape = read_shape(self.path)
        return shape
301
302
    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        full_shape = self.shape
        # Drop the channels dimension
        return full_shape[1:]
306
307
    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if self.is_2d():
            return
        message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
        raise RuntimeError(message)
311
312
    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        spatial_shape = self.spatial_shape
        return spatial_shape[1]
317
318
    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        spatial_shape = self.spatial_shape
        return spatial_shape[0]
323
324
    @property
    def orientation(self) -> tuple[str, str, str]:
        """Orientation codes."""
        # Axis codes are derived from the affine matrix by NiBabel
        codes = nib.aff2axcodes(self.affine)
        return codes
328
329
    @property
    def direction(self) -> TypeDirection3D:
        """Direction, as computed from the affine with ``lps=False``."""
        # Only the third element of the returned metadata is needed
        metadata = get_sitk_metadata_from_ras_affine(self.affine, lps=False)
        _first, _second, direction = metadata
        return direction  # type: ignore[return-value]
336
337
    @property
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm."""
        # The rotation part is not needed here
        _, (sx, sy, sz) = get_rotation_and_spacing_from_affine(self.affine)
        return sx, sy, sz
343
344
    @property
    def origin(self) -> tuple[float, float, float]:
        """Center of first voxel in array, in mm."""
        # Translation part of the affine: first three rows of the last column
        translation = self.affine[:3, 3]
        ox, oy, oz = translation
        return ox, oy, oz
349
350
    @property
    def itemsize(self):
        """Element size of the data type."""
        tensor = self.data
        return tensor.element_size()
354
355
    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        # Number of elements times the size of each element
        num_elements = np.prod(self.shape)
        return num_elements * self.itemsize
359
360
    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest indices."""
        first_index = 0, 0, 0
        last_index = np.array(self.spatial_shape) - 1
        # Map both voxel indices to world coordinates
        first_center = apply_affine(self.affine, first_index)
        last_center = apply_affine(self.affine, last_index)
        return np.array((first_center, last_center))
368
369
    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        # Channels are the first dimension of the tensor
        tensor = self.data
        return len(tensor)
373
374
    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            Negative index of the axis, so that it is valid whether or not
            the tensor has a leading channels dimension.

        Raises:
            ValueError: If ``axis`` is not a string or is not understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        letter = axis[0].upper()

        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous.
        # Generally, TorchIO tensors are (C, W, H, D)
        if letter in 'TB':  # Top, Bottom
            return -2

        try:
            position = self.orientation.index(letter)
        except ValueError:
            # The axis is stored with the opposite label, e.g. 'L' for 'R'
            position = self.orientation.index(self.flip_axis(letter))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return position - 3
410
411
    @staticmethod
412
    def flip_axis(axis: str) -> str:
413
        """Return the opposite axis label. For example, ``'L'`` -> ``'R'``.
414
415
        Args:
416
            axis: Axis label, such as ``'L'`` or ``'left'``.
417
        """
418
        labels = 'LRPAISTBDV'
419
        first = labels[::2]
420
        last = labels[1::2]
421
        flip_dict = dict(zip(first + last, last + first))
422
        axis = axis[0].upper()
423
        flipped_axis = flip_dict.get(axis)
424
        if flipped_axis is None:
425
            values = ', '.join(labels)
426
            message = f'Axis not understood. Please use one of: {values}'
427
            raise ValueError(message)
428
        return flipped_axis
429
430
    def get_spacing_string(self) -> str:
        """Format the spacing with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        formatted = ', '.join(f'{value:.2f}' for value in self.spacing)
        return f'({formatted})'
434
435
    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        # Voxel indices refer to voxel centers, so the image extends half a
        # voxel beyond the first and last indices
        half_voxel = 0.5
        start_index = (-half_voxel,) * 3
        stop_index = np.array(self.spatial_shape) - half_voxel
        start_point = apply_affine(self.affine, start_index)
        stop_point = apply_affine(self.affine, stop_index)
        corners = np.array((start_point, stop_point))
        # Transpose so that each row holds the two values of one axis
        bounds_x, bounds_y, bounds_z = corners.T.tolist()  # type: ignore[misc]
        return bounds_x, bounds_y, bounds_z  # type: ignore[return-value]
444
445
    def _parse_single_path(
446
        self,
447
        path: TypePath,
448
        *,
449
        verify: bool = True,
450
    ) -> Path:
451
        if isinstance(path, (torch.Tensor, np.ndarray)):
452
            class_name = self.__class__.__name__
453
            message = (
454
                'Expected type str or Path but found a tensor/array. Instead of'
455
                f' {class_name}(your_tensor),'
456
                f' use {class_name}(tensor=your_tensor).'
457
            )
458
            raise TypeError(message)
459
        try:
460
            path = Path(path).expanduser()
461
        except TypeError as err:
462
            message = (
463
                f'Expected type str or Path but found an object with type'
464
                f' {type(path)} instead'
465
            )
466
            raise TypeError(message) from err
467
        except RuntimeError as err:
468
            message = f'Conversion to path not possible for variable: {path}'
469
            raise RuntimeError(message) from err
470
        if not verify:
471
            return path
472
473
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
474
            raise FileNotFoundError(f'File not found: "{path}"')
475
        return path
476
477
    def _parse_path(
        self,
        path: TypePath | Sequence[TypePath] | None,
        *,
        verify: bool = True,
    ) -> Path | list[Path] | None:
        """Parse a path, a sequence of paths or ``None``.

        Returns:
            ``None`` if ``path`` is ``None``, a list of paths if a sequence
            was given, or a single path otherwise.

        Raises:
            TypeError: If ``path`` is a dictionary.
        """
        if path is None:
            return None
        if isinstance(path, dict):
            # https://github.com/TorchIO-project/torchio/pull/838
            raise TypeError('The path argument cannot be a dictionary')
        if not self._is_paths_sequence(path):
            return self._parse_single_path(path, verify=verify)  # type: ignore[arg-type]
        return [self._parse_single_path(p, verify=verify) for p in path]  # type: ignore[union-attr]
492
493
    def _parse_tensor(
494
        self,
495
        tensor: TypeData | None,
496
        none_ok: bool = True,
497
    ) -> torch.Tensor | None:
498
        if tensor is None:
499
            if none_ok:
500
                return None
501
            else:
502
                raise RuntimeError('Input tensor cannot be None')
503
        if isinstance(tensor, np.ndarray):
504
            tensor = check_uint_to_int(tensor)
505
            tensor = torch.as_tensor(tensor)
506
        elif not isinstance(tensor, torch.Tensor):
507
            message = (
508
                'Input tensor must be a PyTorch tensor or NumPy array,'
509
                f' but type "{type(tensor)}" was found'
510
            )
511
            raise TypeError(message)
512
        ndim = tensor.ndim
513
        if ndim != 4:
514
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
515
        if tensor.dtype == torch.bool:
516
            tensor = tensor.to(torch.uint8)
517
        if self.check_nans and torch.isnan(tensor).any():
518
            warnings.warn('NaNs found in tensor', RuntimeWarning, stacklevel=2)
519
        return tensor
520
521
    @staticmethod
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        # Delegates to the I/O helper ensure_4d
        reshaped = ensure_4d(tensor)
        return reshaped
524
525
    @staticmethod
526
    def _parse_affine(affine: TypeData | None) -> np.ndarray:
527
        if affine is None:
528
            return np.eye(4)
529
        if isinstance(affine, torch.Tensor):
530
            affine = affine.numpy()
531
        if not isinstance(affine, np.ndarray):
532
            bad_type = type(affine)
533
            raise TypeError(f'Affine must be a NumPy array, not {bad_type}')
534
        if affine.shape != (4, 4):
535
            bad_shape = affine.shape
536
            raise ValueError(f'Affine shape must be (4, 4), not {bad_shape}')
537
        return affine.astype(np.float64)
538
539
    @staticmethod
    def _is_paths_sequence(path: TypePath | Sequence[TypePath] | None) -> bool:
        # Strings are iterable, but a string is a single path
        if isinstance(path, str):
            return False
        return is_iterable(path)
543
544
    def _is_multipath(self) -> bool:
        # Whether the image was created from a sequence of paths
        path = self.path
        return self._is_paths_sequence(path)
546
547
    def _is_dir(self) -> bool:
        # Only a single, existing path can be a directory
        if self._is_multipath():
            return False
        if self.path is None:
            return False
        assert isinstance(self.path, Path)
        return self.path.is_dir()
556
557
    def load(self) -> None:
        r"""Load the image from disk.

        The 4D tensor of size :math:`(C, W, H, D)` and the 2D
        :math:`4 \times 4` affine matrix are stored in the image; nothing is
        returned. If several paths were given, the tensors are concatenated
        along the channels dimension and the affine of the first file is
        kept. Calling this method on an image that is already loaded does
        nothing.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return

        paths: list[Path]
        if self._is_multipath():
            paths = self.path  # type: ignore[assignment]
        else:
            paths = [self.path]  # type: ignore[list-item]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                # Different affines are tolerated; the first one is used
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning, stacklevel=2)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                # Only the channels dimension may differ between files
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True
597
598
    def unload(self) -> None:
        """Unload the image from memory.

        The data and affine entries are set to ``None`` and the image is
        marked as not loaded.

        Raises:
            RuntimeError: If the image has not been loaded yet or if no path
                is available.
        """
        if not self._loaded:
            raise RuntimeError(
                'Image cannot be unloaded as it has not been loaded yet'
            )
        if self.path is None:
            raise RuntimeError(
                'Cannot unload image as no path is available'
                ' from where the image could be loaded again'
            )
        for key in (DATA, AFFINE):
            self[key] = None
        self._loaded = False
617
618
    def read_and_check(self, path: TypePath) -> TypeDataAffine:
        """Read one file with :attr:`reader` and validate the result.

        Args:
            path: Path passed to the reader.

        Returns:
            Tuple with the parsed 4D tensor and the parsed affine matrix.
        """
        tensor, affine = self.reader(path)
        # Make sure the data type is compatible with PyTorch
        uses_custom_reader = self.reader is not read_image
        if uses_custom_reader and isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
        four_d = self._parse_tensor_shape(tensor)
        tensor = self._parse_tensor(four_d)  # type: ignore[assignment]
        affine = self._parse_affine(affine)
        found_nans = self.check_nans and torch.isnan(tensor).any()
        if found_nans:
            warnings.warn(
                f'NaNs found in file "{path}"',
                RuntimeWarning,
                stacklevel=2,
            )
        return tensor, affine
633
634
    def save(self, path: TypePath, squeeze: bool | None = None) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        tensor = self.data
        affine = self.affine
        write_image(tensor, affine, path, squeeze=squeeze)
649
650
    def is_2d(self) -> bool:
        # An image is considered 2D when its last dimension has size 1
        last_size = self.shape[-1]
        return last_size == 1
652
653
    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        # np.asarray relies on Image.__array__
        array = np.asarray(self)
        return array
656
657
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        # Extra keyword arguments are forwarded to nib_to_sitk
        tensor = self.data
        affine = self.affine
        return nib_to_sitk(tensor, affine, **kwargs)
660
661
    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.IntTensor; memory: 93.8 KiB)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 588.0 KiB)
        """
        data_and_affine = sitk_to_nib(sitk_image)
        tensor, affine = data_and_affine
        return cls(tensor=tensor, affine=affine)
677
678
    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        Args:
            transpose: Whether to swap the first two spatial dimensions when
                building the array passed to Pillow.

        Raises:
            RuntimeError: If Pillow is not installed or the image is not 2D.
            NotImplementedError: If the number of channels is not 1, 3 or 4.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        .. note:: To use this method, Pillow needs to be installed:
            ``pip install Pillow``.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = 'Please install Pillow to use Image.as_pil(): pip install Pillow'
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        num_channels = len(tensor)
        if num_channels not in (1, 3, 4):
            raise NotImplementedError(
                'Only 1, 3 or 4 channels are supported for conversion to Pillow image'
            )
        if num_channels == 1:
            # Repeat the single channel three times
            tensor = torch.cat([tensor] * 3)
        # Move channels to the last dimension and the singleton to the first
        axes = (3, 2, 1, 0) if transpose else (3, 1, 2, 0)
        tensor = tensor.permute(*axes)
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))
706
707
    def to_gif(
        self,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        rescale: bool = True,
        optimize: bool = True,
        reverse: bool = False,
    ) -> None:
        """Save an animated GIF of the image.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """
        from ..visualization import make_gif  # avoid circular import

        gif_options = {
            'loop': loop,
            'rescale': rescale,
            'optimize': optimize,
            'reverse': reverse,
        }
        make_gif(self.data, axis, duration, output_path, **gif_options)
744
745
    def to_ras(self) -> Image:
        """Return the image in RAS+ orientation.

        The image itself is returned if its orientation is already RAS+.
        Otherwise, the result of applying ``ToCanonical`` is returned.
        """
        if self.orientation == tuple('RAS'):
            return self

        from ..transforms.preprocessing.spatial.to_canonical import ToCanonical

        transform = ToCanonical()
        return transform(self)
751
752
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        # The center is halfway between the first (0) and the last index
        last_index = np.array(self.spatial_shape) - 1
        center_index = last_index / 2
        r, a, s = apply_affine(self.affine, center_index)
        # Going from RAS+ to LPS+ flips the sign of the first two coordinates
        return (-r, -a, s) if lps else (r, a, s)
767
768
    def set_check_nans(self, check_nans: bool) -> None:
        """Set the ``check_nans`` attribute of the image.

        Args:
            check_nans: Value to store in ``self.check_nans``.
        """
        self.check_nans = check_nans
770
771
    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are converted with :meth:`as_pil` and shown with the
        ``show`` method of the result; ``kwargs`` are not used in that case.
        Any other image is passed, together with ``kwargs``, to
        ``plot_volume``.

        Args:
            **kwargs: Keyword arguments forwarded to ``plot_volume``.
        """
        if not self.is_2d():
            from ..visualization import plot_volume  # avoid circular import

            plot_volume(self, **kwargs)
            return

        pil_image = self.as_pil()
        pil_image.show()
779
780
    def show(self, viewer_path: TypePath | None = None) -> None:
        """Open the image using external software.

        Args:
            viewer_path: Path to the application used to view the image. If
                ``None``, the value of the environment variable
                ``SITK_SHOW_COMMAND`` will be used. If this variable is also
                not set, TorchIO will try to guess the location of
                `ITK-SNAP <http://www.itksnap.org/pmwiki/pmwiki.php>`_ and
                `3D Slicer <https://www.slicer.org/>`_.

        Raises:
            RuntimeError: If the viewer is not found.
        """
        sitk_image = self.as_sitk()
        viewer = sitk.ImageViewer()

        # This is so that 3D Slicer creates segmentation nodes from label maps
        is_label_map = self.__class__.__name__ == 'LabelMap'
        if is_label_map:
            viewer.SetFileExtension('.seg.nrrd')

        if viewer_path is not None:
            viewer.SetApplication(str(viewer_path))

        try:
            viewer.Execute(sitk_image)
        except RuntimeError as error:
            # First attempt failed: retry once with a viewer guessed by
            # TorchIO, or give up if none can be guessed
            guessed_path = guess_external_viewer()
            if guessed_path is None:
                message = (
                    'No external viewer has been found. Please set the'
                    ' environment variable SITK_SHOW_COMMAND to a viewer of'
                    ' your choice'
                )
                raise RuntimeError(message) from error
            viewer.SetApplication(str(guessed_path))
            viewer.Execute(sitk_image)
814
815
    def _crop_from_slices(
        self,
        slices: TypeSlice | tuple[TypeSlice, ...],
    ) -> Image:
        """Crop the image as described by indices over the spatial dimensions.

        Each index is converted into the number of voxels to remove at the
        beginning and at the end of the corresponding spatial dimension, and
        the six resulting values are passed to the ``Crop`` transform.
        Dimensions for which no index is given are not cropped.

        Args:
            slices: An ``int``, a ``slice`` or a tuple of them, one per
                spatial dimension, starting from the first one.

        Returns:
            The result of applying ``Crop`` to this image.

        Raises:
            NotImplementedError: If an index is ``Ellipsis``.
            TypeError: If an index is neither an ``int`` nor a ``slice``.
            ValueError: If a slice has a step different from 1.
        """
        from ..transforms import Crop

        # Two values (initial and final) are needed per spatial dimension
        num_cropping_values = 6

        slices_tuple = to_tuple(slices)  # type: ignore[assignment]
        cropping: list[int] = []
        for dim, index in enumerate(slices_tuple):
            if isinstance(index, slice):
                index_slice = index
            elif index is Ellipsis:
                message = 'Ellipsis slicing is not supported yet'
                raise NotImplementedError(message)
            elif isinstance(index, int):
                # For -1, "index + 1" is 0 and slice(-1, 0) would be empty, so
                # None is used as stop to reach the end of the dimension
                index_slice = slice(index, (index + 1) or None)
            else:
                message = f'Slice type not understood: "{type(index)}"'
                raise TypeError(message)
            shape_dim = self.spatial_shape[dim]
            start, stop, step = index_slice.indices(shape_dim)
            if step != 1:
                message = (
                    'Slicing with steps different from 1 is not supported yet.'
                    ' Use the Crop transform instead'
                )
                raise ValueError(message)
            crop_ini = start
            crop_fin = shape_dim - stop
            cropping.extend([crop_ini, crop_fin])

        # Dimensions without an index are left uncropped. Padding from the
        # length of the list (instead of from the loop variable) also works
        # when no indices were given and the loop above did not run
        cropping.extend([0] * (num_cropping_values - len(cropping)))

        w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
        cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin  # making mypy happy
        return Crop(cropping_arg)(self)  # type: ignore[return-value]
852
853
854
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The type of these images is always :attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        """Create the image, forcing its type to be ``INTENSITY``.

        Raises:
            ValueError: If ``type`` is passed with a value other than
                ``INTENSITY``.
        """
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)

    def hist(self, **kwargs) -> None:
        """Plot histogram.

        Args:
            **kwargs: Keyword arguments forwarded to ``plot_histogram``.
        """
        from ..visualization import plot_histogram

        flat_values = self.data.flatten().numpy()
        plot_histogram(flat_values, **kwargs)

    def to_video(
        self,
        output_path: TypePath,
        frame_rate: float | None = 15,
        seconds: float | None = None,
        direction: str = 'I',
        verbosity: str = 'error',
    ) -> None:
        """Create a video showing all image slices along a specified direction.

        The image is converted with :meth:`to_ras` before being passed to
        ``make_video``.

        Args:
            output_path: Path to the output video file.
            frame_rate: Number of frames per second (FPS).
            seconds: Target duration of the full video.
            direction: Direction along which the slices are shown. Passed
                unchanged to ``make_video``.
            verbosity: Verbosity setting. Passed unchanged to ``make_video``.

        .. note:: Only ``frame_rate`` or ``seconds`` may (and must) be specified.
        """
        from ..visualization import make_video  # avoid circular import

        video_options = {
            'frame_rate': frame_rate,
            'seconds': seconds,
            'direction': direction,
            'verbosity': verbosity,
        }
        make_video(
            self.to_ras(),  # type: ignore[arg-type]
            output_path,
            **video_options,
        )
919
920
921
class LabelMap(Image):
    """Image whose pixel values represent segmentation labels.

    A sequence of paths to 3D images can be passed to create a 4D image.
    This is useful to create a
    `tissue probability map (TPM) <https://andysbrainbook.readthedocs.io/en/latest/SPM/SPM_Short_Course/SPM_04_Preprocessing/04_SPM_Segmentation.html#tissue-probability-maps>`_,
    which contains the probability of each voxel belonging to a certain tissue type,
    or to create a label map with overlapping labels.

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> binary_tensor = torch.rand(1, 128, 128, 68) > 0.5
        >>> label_map = tio.LabelMap(tensor=binary_tensor)  # from a tensor
        >>> label_map = tio.LabelMap('t1_seg.nii.gz')  # from a file
        >>> # Create a 4D tissue probability map from different 3D images
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
        >>> tpm = tio.LabelMap(tissues)

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        """Create the image, forcing its type to be ``LABEL``.

        Raises:
            ValueError: If ``type`` is passed with a value other than
                ``LABEL``.
        """
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)

    def count_nonzero(self) -> int:
        """Get the number of voxels that are not 0."""
        return int(self.data.count_nonzero().item())

    def count_labels(self) -> dict[int, int]:
        """Get the number of voxels in each label.

        Returns:
            Dictionary mapping each value present in the data to the number of
            voxels with that value. Keys are in ascending order.
        """
        # Counting is done on the tensor so that only the unique values, and
        # not every voxel, are converted to Python objects
        labels, counts = self.data.unique(return_counts=True)
        return dict(zip(labels.tolist(), counts.tolist()))
965