Passed
Pull Request — main (#1400)
by
unknown
01:35
created

torchio.data.image.Image.numpy()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import warnings
4
from collections import Counter
5
from collections.abc import Callable
6
from collections.abc import Sequence
7
from pathlib import Path
8
from typing import Any
9
10
import humanize
11
import nibabel as nib
12
import numpy as np
13
import SimpleITK as sitk
14
import torch
15
from deprecated import deprecated
16
from nibabel.affines import apply_affine
17
18
from ..constants import AFFINE
19
from ..constants import DATA
20
from ..constants import INTENSITY
21
from ..constants import LABEL
22
from ..constants import PATH
23
from ..constants import STEM
24
from ..constants import TENSOR
25
from ..constants import TYPE
26
from ..types import TypeData
27
from ..types import TypeDataAffine
28
from ..types import TypeDirection3D
29
from ..types import TypePath
30
from ..types import TypeQuartetInt
31
from ..types import TypeSlice
32
from ..types import TypeTripletFloat
33
from ..types import TypeTripletInt
34
from ..utils import get_stem
35
from ..utils import guess_external_viewer
36
from ..utils import is_iterable
37
from ..utils import to_tuple
38
from .io import check_uint_to_int
39
from .io import ensure_4d
40
from .io import get_rotation_and_spacing_from_affine
41
from .io import get_sitk_metadata_from_ras_affine
42
from .io import nib_to_sitk
43
from .io import read_affine
44
from .io import read_image
45
from .io import read_shape
46
from .io import sitk_to_nib
47
from .io import write_image
48
49
# Dictionary keys managed by Image itself; passing them as extra keyword
# arguments to Image() raises ValueError, and __copy__ skips them.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
# A pair of world coordinates along one axis, and one such pair per axis.
TypeBound = tuple[float, float]
TypeBounds = tuple[TypeBound, TypeBound, TypeBound]

# Message attached to the deprecated ``Image.data`` property setter.
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
57
58
59
class Image(dict):
60
    r"""TorchIO image.
61
62
    For information about medical image orientation, check out `NiBabel docs`_,
63
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
64
    `SimpleITK docs`_.
65
66
    Args:
67
        path: Path to a file or sequence of paths to files that can be read by
68
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
69
            DICOM files. If :attr:`tensor` is given, the data in
70
            :attr:`path` will not be read.
71
            If a sequence of paths is given, data
72
            will be concatenated on the channel dimension so spatial
73
            dimensions must match.
74
        type: Type of image, such as :attr:`torchio.INTENSITY` or
75
            :attr:`torchio.LABEL`. This will be used by the transforms to
76
            decide whether to apply an operation, or which interpolation to use
77
            when resampling. For example, `preprocessing`_ and `augmentation`_
78
            intensity transforms will only be applied to images with type
79
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
80
            all types, and nearest neighbor interpolation is always used to
81
            resample images with type :attr:`torchio.LABEL`.
82
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
83
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
84
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
85
            :class:`torch.Tensor` or NumPy array with dimensions
86
            :math:`(C, W, H, D)`.
87
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
88
            coordinates. If ``None``, an identity matrix will be used. See the
89
            `NiBabel docs on coordinates`_ for more information.
90
        check_nans: If ``True``, issues a warning if NaNs are found
91
            in the image. If ``False``, images will not be checked for the
92
            presence of NaNs.
93
        reader: Callable object that takes a path and returns a 4D tensor and a
94
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
95
            is saved in a custom format, such as ``.npy`` (see example below).
96
            If the affine matrix is ``None``, an identity matrix will be used.
97
        **kwargs: Items that will be added to the image dictionary, e.g.
98
            acquisition parameters or image ID.
99
        verify_path: If ``True``, the path will be checked to see if it exists. If
100
            ``False``, the path will not be verified. This is useful when it is
101
            expensive to check the path, e.g., when reading a large dataset from a
102
            mounted drive.
103
104
    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
105
    when needed.
106
107
    Example:
108
        >>> import torchio as tio
109
        >>> import numpy as np
110
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
111
        >>> image  # not loaded yet
112
        ScalarImage(path: t1.nii.gz; type: intensity)
113
        >>> times_two = 2 * image.data  # data is loaded and cached here
114
        >>> image
115
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
116
        >>> image.save('doubled_image.nii.gz')
117
        >>> def numpy_reader(path):
118
        ...     data = np.load(path).astype(np.float32)
119
        ...     affine = np.eye(4)
120
        ...     return data, affine
121
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)
122
123
    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
124
    .. _preprocessing: https://docs.torchio.org/transforms/preprocessing.html#intensity
125
    .. _augmentation: https://docs.torchio.org/transforms/augmentation.html#intensity
126
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
127
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
128
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
129
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
130
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
131
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
132
    """
133
134
    def __init__(
        self,
        path: TypePath | Sequence[TypePath] | None = None,
        type: str | None = None,  # noqa: A002
        tensor: TypeData | None = None,
        affine: TypeData | None = None,
        check_nans: bool = False,  # removed by ITK by default
        reader: Callable[[TypePath], TypeDataAffine] = read_image,
        verify_path: bool = True,
        **kwargs: Any,
    ):
        # Must be set before any tensor parsing: _parse_tensor() reads it
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                FutureWarning,
                stacklevel=2,
            )
            type = INTENSITY  # noqa: A001

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        if tensor is not None:
            # set_data() validates the tensor (including the optional NaN
            # check) exactly once and marks the image as loaded; parsing it
            # here as well would scan for NaNs and warn twice
            self.set_data(tensor)
            self.affine = affine  # the setter validates the matrix
        else:
            # Validate the affine eagerly even if no tensor was given, so
            # invalid input is reported at construction time
            self._parse_affine(affine)

        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        if 'channels_last' in kwargs:
            message = (
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/TorchIO-project/torchio/pull/685 and will be'
                ' removed in the future'
            )
            warnings.warn(message, FutureWarning, stacklevel=2)

        # Any remaining keyword arguments become regular dictionary items
        super().__init__(**kwargs)
        self.path = self._parse_path(path, verify=verify_path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
186
187
    def __repr__(self):
188
        properties = []
189
        properties.extend(
190
            [
191
                f'shape: {self.shape}',
192
                f'spacing: {self.get_spacing_string()}',
193
                f'orientation: {self.orientation_str}+',
194
            ]
195
        )
196
        if self._loaded:
197
            properties.append(f'dtype: {self.data.type()}')
198
            natural = humanize.naturalsize(self.memory, binary=True)
199
            properties.append(f'memory: {natural}')
200
        else:
201
            properties.append(f'path: "{self.path}"')
202
203
        properties = '; '.join(properties)
204
        string = f'{self.__class__.__name__}({properties})'
205
        return string
206
207
    def __getitem__(self, item):
        """Return a cropped image for slice-like keys, else a dict value.

        Integer, slice and tuple keys are delegated to
        :meth:`_crop_from_slices`. The ``DATA`` and ``AFFINE`` entries are
        loaded lazily: :meth:`load` is called when they have not been stored
        yet, and also when they hold ``None``, which is what :meth:`unload`
        stores. Without the second case, reading ``image.data`` after
        ``unload()`` would return ``None`` instead of reloading from disk.
        """
        if isinstance(item, (slice, int, tuple)):
            return self._crop_from_slices(item)

        # dict.get() does not go through this override, so no recursion
        if item in (DATA, AFFINE) and dict.get(self, item) is None:
            self.load()
        return super().__getitem__(item)
215
216
    def __array__(self):
217
        return self.data.numpy()
218
219
    def __copy__(self):
        """Return a shallow copy of the image.

        A loaded tensor and affine are passed to the new instance without
        cloning. Extra dictionary items are shared by reference. The new
        image keeps the same type, path, NaN-check setting and reader.
        """
        kwargs = {
            TYPE: self.type,
            PATH: self.path,
        }
        if self._loaded:
            kwargs[TENSOR] = self.data
            kwargs[AFFINE] = self.affine
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are not copied
        new_image_class = type(self)
        new_image = new_image_class(
            check_nans=self.check_nans,
            reader=self.reader,
            # self.path was already parsed (and verified, if requested) when
            # this image was created. Checking the filesystem again would
            # make copying fail for images built with verify_path=False.
            verify_path=False,
            **kwargs,
        )
        return new_image
238
239
    @property
240
    def data(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.tensor`).

        Reading it triggers :meth:`load` when no data has been stored yet.
        """
        stored = self[DATA]
        return stored
243
244
    @data.setter  # type: ignore[misc]
245
    @deprecated(version='0.18.16', reason=deprecation_message)
246
    def data(self, tensor: TypeData):
247
        self.set_data(tensor)
248
249
    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        The input goes through :meth:`_parse_tensor`, which does not accept
        ``None`` here, and the image is then flagged as loaded.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        validated = self._parse_tensor(tensor, none_ok=False)
        self[DATA] = validated
        self._loaded = True
257
258
    @property
259
    def tensor(self) -> torch.Tensor:
260
        """Tensor data (same as :class:`Image.data`)."""
261
        return self.data
262
263
    @property
264
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates.

        The stored matrix is used (loading the image if necessary) when the
        data is already in memory, or when reading a single header is not
        possible or reliable. That covers a directory (probably DICOM), a
        list of paths used to build a 4D image, and a custom reader, whose
        format SimpleITK probably cannot parse. Otherwise only the file
        header is read with ``read_affine``.
        """
        use_stored_matrix = (
            self._loaded
            or self._is_dir()
            or self._is_multipath()
            or self.reader is not read_image
        )
        if use_stored_matrix:
            return self[AFFINE]
        assert self.path is not None
        assert isinstance(self.path, (str, Path))
        return read_affine(self.path)
278
279
    @affine.setter
280
    def affine(self, matrix):
        """Validate ``matrix`` with :meth:`_parse_affine` and store it."""
        validated = self._parse_affine(matrix)
        self[AFFINE] = validated
282
283
    @property
284
    def type(self) -> str:  # noqa: A003
        """Image type stored under the ``TYPE`` key, e.g. :attr:`torchio.INTENSITY`."""
        image_type = self[TYPE]
        return image_type
286
287
    @property
288
    def shape(self) -> TypeQuartetInt:
        """Tensor shape as :math:`(C, W, H, D)`.

        If the data is not in memory and the image is a single file read
        with the default reader, only the header is read (``read_shape``).
        Otherwise the shape of the (possibly freshly loaded) data is returned.
        """
        custom_reader = self.reader is not read_image
        shape: TypeQuartetInt
        # _is_dir() returns False for None or several paths, so it is safe
        # for any path value. Being last in the chain, the filesystem is
        # only queried when the cheaper checks fail.
        if self._loaded or custom_reader or self._is_multipath() or self._is_dir():
            channels, si, sj, sk = self.data.shape
            shape = channels, si, sj, sk
        else:
            assert isinstance(self.path, (str, Path))
            shape = read_shape(self.path)
        return shape
302
303
    @property
304
    def spatial_shape(self) -> TypeTripletInt:
305
        """Tensor spatial shape as :math:`(W, H, D)`."""
306
        return self.shape[1:]
307
308
    def check_is_2d(self) -> None:
309
        if not self.is_2d():
310
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
311
            raise RuntimeError(message)
312
313
    @property
314
    def height(self) -> int:
315
        """Image height, if 2D."""
316
        self.check_is_2d()
317
        return self.spatial_shape[1]
318
319
    @property
320
    def width(self) -> int:
321
        """Image width, if 2D."""
322
        self.check_is_2d()
323
        return self.spatial_shape[0]
324
325
    @property
326
    def orientation(self) -> tuple[str, str, str]:
        """Orientation codes computed from :attr:`affine` by NiBabel's ``aff2axcodes``."""
        axis_codes = nib.orientations.aff2axcodes(self.affine)
        return axis_codes
329
330
    @property
331
    def orientation_str(self) -> str:
332
        """Orientation as a string."""
333
        return ''.join(self.orientation)
334
335
    @property
336
    def direction(self) -> TypeDirection3D:
        """Direction component of the SimpleITK metadata for :attr:`affine`.

        Computed by ``get_sitk_metadata_from_ras_affine`` with ``lps=False``.
        """
        metadata = get_sitk_metadata_from_ras_affine(self.affine, lps=False)
        _, _, direction = metadata
        return direction  # type: ignore[return-value]
342
343
    @property
344
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm, as a tuple of three Python floats."""
        _, voxel_sizes = get_rotation_and_spacing_from_affine(self.affine)
        sx, sy, sz = voxel_sizes
        return tuple(float(size) for size in (sx, sy, sz))  # type: ignore[return-value]
349
350
    @property
351
    def origin(self) -> tuple[float, float, float]:
352
        """Center of first voxel in array, in mm."""
353
        ox, oy, oz = self.affine[:3, 3]
354
        return ox, oy, oz
355
356
    @property
357
    def itemsize(self):
358
        """Element size of the data type."""
359
        return self.data.element_size()
360
361
    @property
362
    def memory(self) -> float:
363
        """Number of Bytes that the tensor takes in the RAM."""
364
        return np.prod(self.shape) * self.itemsize
365
366
    @property
367
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest indices.

        Row 0 is the world position of voxel ``(0, 0, 0)``; row 1 is that of
        the voxel at ``spatial_shape - 1``.
        """
        last_voxel = np.array(self.spatial_shape) - 1
        centers = [
            apply_affine(self.affine, voxel) for voxel in ((0, 0, 0), last_voxel)
        ]
        return np.array(centers)
374
375
    @property
376
    def num_channels(self) -> int:
377
        """Get the number of channels in the associated 4D tensor."""
378
        return len(self.data)
379
380
    def axis_name_to_index(self, axis: str) -> int:
381
        """Convert an axis name to an axis index.
382
383
        Args:
384
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
385
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
386
                versions and first letters are also valid, as only the first
387
                letter will be used.
388
389
        .. note:: If you are working with animals, you should probably use
390
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
391
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
392
            respectively.
393
394
        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
395
            ``'Left'`` and ``'Right'``.
396
        """
397
        # Top and bottom are used for the vertical 2D axis as the use of
398
        # Height vs Horizontal might be ambiguous
399
400
        if not isinstance(axis, str):
401
            raise ValueError('Axis must be a string')
402
        axis = axis[0].upper()
403
404
        # Generally, TorchIO tensors are (C, W, H, D)
405
        if axis in 'TB':  # Top, Bottom
406
            return -2
407
        else:
408
            try:
409
                index = self.orientation.index(axis)
410
            except ValueError:
411
                index = self.orientation.index(self.flip_axis(axis))
412
            # Return negative indices so that it does not matter whether we
413
            # refer to spatial dimensions or not
414
            index = -3 + index
415
            return index
416
417
    @staticmethod
418
    def flip_axis(axis: str) -> str:
419
        """Return the opposite axis label. For example, ``'L'`` -> ``'R'``.
420
421
        Args:
422
            axis: Axis label, such as ``'L'`` or ``'left'``.
423
        """
424
        labels = 'LRPAISTBDV'
425
        first = labels[::2]
426
        last = labels[1::2]
427
        flip_dict = dict(zip(first + last, last + first, strict=True))
428
        axis = axis[0].upper()
429
        flipped_axis = flip_dict.get(axis)
430
        if flipped_axis is None:
431
            values = ', '.join(labels)
432
            message = f'Axis not understood. Please use one of: {values}'
433
            raise ValueError(message)
434
        return flipped_axis
435
436
    def get_spacing_string(self) -> str:
437
        strings = [f'{n:.2f}' for n in self.spacing]
438
        string = f'({", ".join(strings)})'
439
        return string
440
441
    def get_bounds(self) -> TypeBounds:
        """Get the world coordinates of the outer corners of the image.

        The voxel indices ``(-0.5, -0.5, -0.5)`` and ``spatial_shape - 0.5``
        (outer corners of the first and last voxels) are mapped through
        :attr:`affine`. One ``[first, last]`` list is returned per axis. The
        pair is not sorted, so ``first`` may exceed ``last`` for an axis
        that points in the negative direction.
        """
        upper_index = np.array(self.spatial_shape) - 0.5
        lower_corner = apply_affine(self.affine, (-0.5, -0.5, -0.5))
        upper_corner = apply_affine(self.affine, upper_index)
        corners = np.array((lower_corner, upper_corner))
        # Transpose so each row holds one axis: [first, last]
        bounds_x, bounds_y, bounds_z = corners.T.tolist()  # type: ignore[misc]
        return bounds_x, bounds_y, bounds_z  # type: ignore[return-value]
450
451
    def _parse_single_path(
452
        self,
453
        path: TypePath,
454
        *,
455
        verify: bool = True,
456
    ) -> Path:
457
        if isinstance(path, (torch.Tensor, np.ndarray)):
458
            class_name = self.__class__.__name__
459
            message = (
460
                'Expected type str or Path but found a tensor/array. Instead of'
461
                f' {class_name}(your_tensor),'
462
                f' use {class_name}(tensor=your_tensor).'
463
            )
464
            raise TypeError(message)
465
        try:
466
            path = Path(path).expanduser()
467
        except TypeError as err:
468
            message = (
469
                f'Expected type str or Path but found an object with type'
470
                f' {type(path)} instead'
471
            )
472
            raise TypeError(message) from err
473
        except RuntimeError as err:
474
            message = f'Conversion to path not possible for variable: {path}'
475
            raise RuntimeError(message) from err
476
        if not verify:
477
            return path
478
479
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
480
            raise FileNotFoundError(f'File not found: "{path}"')
481
        return path
482
483
    def _parse_path(
484
        self,
485
        path: TypePath | Sequence[TypePath] | None,
486
        *,
487
        verify: bool = True,
488
    ) -> Path | list[Path] | None:
489
        if path is None:
490
            return None
491
        elif isinstance(path, dict):
492
            # https://github.com/TorchIO-project/torchio/pull/838
493
            raise TypeError('The path argument cannot be a dictionary')
494
        elif self._is_paths_sequence(path):
495
            return [self._parse_single_path(p, verify=verify) for p in path]  # type: ignore[union-attr]
496
        else:
497
            return self._parse_single_path(path, verify=verify)  # type: ignore[arg-type]
498
499
    def _parse_tensor(
500
        self,
501
        tensor: TypeData | None,
502
        none_ok: bool = True,
503
    ) -> torch.Tensor | None:
504
        if tensor is None:
505
            if none_ok:
506
                return None
507
            else:
508
                raise RuntimeError('Input tensor cannot be None')
509
        if isinstance(tensor, np.ndarray):
510
            tensor = check_uint_to_int(tensor)
511
            tensor = torch.as_tensor(tensor)
512
        elif not isinstance(tensor, torch.Tensor):
513
            message = (
514
                'Input tensor must be a PyTorch tensor or NumPy array,'
515
                f' but type "{type(tensor)}" was found'
516
            )
517
            raise TypeError(message)
518
        ndim = tensor.ndim
519
        if ndim != 4:
520
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
521
        if tensor.dtype == torch.bool:
522
            tensor = tensor.to(torch.uint8)
523
        if self.check_nans and torch.isnan(tensor).any():
524
            warnings.warn('NaNs found in tensor', RuntimeWarning, stacklevel=2)
525
        return tensor
526
527
    @staticmethod
528
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        """Bring reader output to four dimensions with ``ensure_4d``."""
        four_dimensional = ensure_4d(tensor)
        return four_dimensional
530
531
    @staticmethod
532
    def _parse_affine(affine: TypeData | None) -> np.ndarray:
533
        if affine is None:
534
            return np.eye(4)
535
        if isinstance(affine, torch.Tensor):
536
            affine = affine.numpy()
537
        if not isinstance(affine, np.ndarray):
538
            bad_type = type(affine)
539
            raise TypeError(f'Affine must be a NumPy array, not {bad_type}')
540
        if affine.shape != (4, 4):
541
            bad_shape = affine.shape
542
            raise ValueError(f'Affine shape must be (4, 4), not {bad_shape}')
543
        return affine.astype(np.float64)
544
545
    @staticmethod
546
    def _is_paths_sequence(path: TypePath | Sequence[TypePath] | None) -> bool:
547
        is_not_string = not isinstance(path, str)
548
        return is_not_string and is_iterable(path)
549
550
    def _is_multipath(self) -> bool:
551
        return self._is_paths_sequence(self.path)
552
553
    def _is_dir(self) -> bool:
554
        is_sequence = self._is_multipath()
555
        if is_sequence:
556
            return False
557
        elif self.path is None:
558
            return False
559
        else:
560
            assert isinstance(self.path, Path)
561
            return self.path.is_dir()
562
563
    def load(self) -> None:
        r"""Load the image from disk.

        Each path is read with :meth:`read_and_check`. The tensors are
        concatenated along the channel dimension and stored with
        :meth:`set_data`, and the affine of the first file is stored in
        :attr:`affine`. Nothing happens if the image is already loaded.

        Raises:
            RuntimeError: If the files of a multi-path image have different
                spatial shapes.

        Warns:
            RuntimeWarning: If the files of a multi-path image have
                different affine matrices (the first one is kept).
        """
        if self._loaded:
            return

        paths: list[Path]
        if self._is_multipath():
            paths = self.path  # type: ignore[assignment]
        else:
            paths = [self.path]  # type: ignore[list-item]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning, stacklevel=2)
            # Channel counts may differ, but spatial dimensions must match
            # for the tensors to be concatenated on the channel dimension
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True
603
604
    def unload(self) -> None:
605
        """Unload the image from memory.
606
607
        Raises:
608
            RuntimeError: If the images has not been loaded yet or if no path
609
                is available.
610
        """
611
        if not self._loaded:
612
            message = 'Image cannot be unloaded as it has not been loaded yet'
613
            raise RuntimeError(message)
614
        if self.path is None:
615
            message = (
616
                'Cannot unload image as no path is available'
617
                ' from where the image could be loaded again'
618
            )
619
            raise RuntimeError(message)
620
        self[DATA] = None
621
        self[AFFINE] = None
622
        self._loaded = False
623
624
    def read_and_check(self, path: TypePath) -> TypeDataAffine:
        """Read one file with :attr:`reader` and validate its output.

        Returns:
            The validated 4D tensor and the float64 affine matrix.
        """
        data, matrix = self.reader(path)
        # Make sure the data type is compatible with PyTorch
        uses_custom_reader = self.reader is not read_image
        if uses_custom_reader and isinstance(data, np.ndarray):
            data = check_uint_to_int(data)
        data = self._parse_tensor_shape(data)  # type: ignore[assignment]
        data = self._parse_tensor(data)  # type: ignore[assignment]
        matrix = self._parse_affine(matrix)
        if self.check_nans and torch.isnan(data).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning, stacklevel=2)
        return data, matrix
639
640
    def save(self, path: TypePath, squeeze: bool | None = None) -> None:
        """Save image to disk with ``write_image``.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        data, affine = self.data, self.affine
        write_image(data, affine, path, squeeze=squeeze)
655
656
    def is_2d(self) -> bool:
657
        return self.shape[-1] == 1
658
659
    def numpy(self) -> np.ndarray:
660
        """Get a NumPy array containing the image data."""
661
        return np.asarray(self)
662
663
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`.

        Args:
            **kwargs: Forwarded to ``nib_to_sitk``.
        """
        tensor, affine = self.data, self.affine
        return nib_to_sitk(tensor, affine, **kwargs)
666
667
    @classmethod
668
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        The image is converted with ``sitk_to_nib`` and passed to the class
        constructor as ``tensor`` and ``affine``.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 93.8 KiB; dtype: torch.IntTensor)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 588.0 KiB; dtype: torch.FloatTensor)
        """
        converted_tensor, converted_affine = sitk_to_nib(sitk_image)
        new_image = cls(tensor=converted_tensor, affine=converted_affine)
        return new_image
683
684
    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        Args:
            transpose: If ``True``, the first spatial axis becomes the PIL
                columns (width); otherwise it becomes the PIL rows.

        Raises:
            RuntimeError: If Pillow is not installed or the image is not 2D.
            NotImplementedError: If the number of channels is not 1, 3 or 4.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        .. note:: To use this method, Pillow needs to be installed:
            ``pip install Pillow``.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = 'Please install Pillow to use Image.as_pil(): pip install Pillow'
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        num_channels = len(tensor)
        if num_channels not in (1, 3, 4):
            raise NotImplementedError(
                'Only 1, 3 or 4 channels are supported for conversion to Pillow image'
            )
        if num_channels == 1:
            # Replicate a single channel into three (grayscale as RGB)
            tensor = tensor.repeat(3, 1, 1, 1)
        # Move the singleton last spatial axis first and channels last
        axes_order = (3, 2, 1, 0) if transpose else (3, 1, 2, 0)
        array = tensor.permute(*axes_order).clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))
712
713
    def to_gif(
        self,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        rescale: bool = True,
        optimize: bool = True,
        reverse: bool = False,
    ) -> None:
        """Save an animated GIF of the image using ``make_gif``.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """
        from ..visualization import make_gif  # avoid circular import

        options = {
            'loop': loop,
            'rescale': rescale,
            'optimize': optimize,
            'reverse': reverse,
        }
        make_gif(self.data, axis, duration, output_path, **options)
750
751
    def to_ras(self) -> Image:
752
        if self.orientation_str != 'RAS':
753
            from ..transforms.preprocessing.spatial.to_canonical import ToCanonical
754
755
            return ToCanonical()(self)
756
        return self
757
758
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        The center is the voxel index ``(spatial_shape - 1) / 2`` mapped
        through :attr:`affine`.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        center_index = (np.array(self.spatial_shape) - 1) / 2
        r, a, s = apply_affine(self.affine, center_index)
        # LPS+ negates the first two world axes of RAS+
        return (-r, -a, s) if lps else (r, a, s)
773
774
    def set_check_nans(self, check_nans: bool) -> None:
        """Set the ``check_nans`` attribute of this image.

        Args:
            check_nans: New value stored in ``self.check_nans``.
        """
        setattr(self, 'check_nans', check_nans)
776
777
    def new_like(self, tensor: TypeData, affine: TypeData | None = None) -> Image:
        """Return a new image of the same class that holds different data.

        The result is an instance of ``type(self)`` built from ``tensor``. It
        carries over this image's ``type``, ``check_nans`` and ``reader``.
        Transforms can use it to create output images that keep the class of
        custom :class:`Image` subclasses.

        If calling the constructor with those keyword arguments raises
        :class:`TypeError`, this image is deep-copied instead and the copy's
        data and affine are replaced. This happens, for example, with a
        subclass whose ``__init__`` requires extra arguments. Presumably the
        deep copy also duplicates the current data before it is replaced.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)` for the new image.
            affine: :math:`4 \\times 4` matrix to convert voxel coordinates to world
                coordinates. If ``None``, the affine of this image is reused.

        Returns:
            A new image whose class is the same as this image's.

        Example:
            >>> import torch
            >>> import torchio as tio
            >>> # Standard usage
            >>> image = tio.ScalarImage('path/to/image.nii.gz')
            >>> new_tensor = torch.rand(1, 64, 64, 64)
            >>> new_image = image.new_like(tensor=new_tensor)
            >>> isinstance(new_image, tio.ScalarImage)
            True

            >>> # Custom subclass usage
            >>> class CustomImage(tio.ScalarImage):
            ...     def __init__(self, tensor, affine, metadata, **kwargs):
            ...         super().__init__(tensor=tensor, affine=affine, **kwargs)
            ...         self.metadata = metadata
            ...
            ...     def new_like(self, tensor, affine=None):
            ...         return type(self)(
            ...             tensor=tensor,
            ...             affine=affine if affine is not None else self.affine,
            ...             metadata=self.metadata,  # Preserve custom attribute
            ...             check_nans=self.check_nans,
            ...             reader=self.reader,
            ...         )
            >>> custom = CustomImage(torch.rand(1, 32, 32, 32), torch.eye(4), {'id': 123})
            >>> new_custom = custom.new_like(torch.rand(1, 16, 16, 16))
            >>> new_custom.metadata['id']
            123
        """
        new_affine = self.affine if affine is None else affine

        image_class = type(self)
        try:
            # Preferred path: build a fresh instance via the regular constructor
            return image_class(
                tensor=tensor,
                affine=new_affine,
                type=self.type,
                check_nans=self.check_nans,
                reader=self.reader,
            )
        except TypeError:
            # The constructor does not accept these arguments (e.g. a subclass
            # with extra required parameters): clone and swap data instead
            import copy

            clone = copy.deepcopy(self)
            clone.set_data(tensor)
            clone.affine = new_affine
            return clone
843
844
    def plot(self, **kwargs) -> None:
        """Display the image.

        2D images are shown with ``self.as_pil().show()``, and ``kwargs`` are
        ignored in that case. Other images are passed together with
        ``kwargs`` to :func:`~torchio.visualization.plot_volume`.
        """
        if not self.is_2d():
            from ..visualization import plot_volume  # avoid circular import

            plot_volume(self, **kwargs)
            return
        self.as_pil().show()
852
853
    def show(self, viewer_path: TypePath | None = None) -> None:
        """Open the image using external software.

        The image is converted with :meth:`as_sitk` and handed to
        :class:`SimpleITK.ImageViewer`. Label images (``type == LABEL``) are
        written with a ``.seg.nrrd`` extension.

        Args:
            viewer_path: Path to the application used to view the image. If
                ``None``, the value of the environment variable
                ``SITK_SHOW_COMMAND`` will be used. If this variable is also
                not set, TorchIO will try to guess the location of
                `ITK-SNAP <http://www.itksnap.org/pmwiki/pmwiki.php>`_ and
                `3D Slicer <https://www.slicer.org/>`_.

        Raises:
            RuntimeError: If the viewer is not found.
        """
        sitk_image = self.as_sitk()
        image_viewer = sitk.ImageViewer()
        # This is so that 3D Slicer creates segmentation nodes from label maps.
        # The image type is checked instead of the class name so that
        # subclasses of LabelMap (whose type is always LABEL) are included.
        if self.type == LABEL:
            image_viewer.SetFileExtension('.seg.nrrd')
        if viewer_path is not None:
            image_viewer.SetApplication(str(viewer_path))
        try:
            image_viewer.Execute(sitk_image)
        except RuntimeError as e:
            # The configured/default viewer failed: try to locate a known one
            viewer_path = guess_external_viewer()
            if viewer_path is None:
                message = (
                    'No external viewer has been found. Please set the'
                    ' environment variable SITK_SHOW_COMMAND to a viewer of'
                    ' your choice'
                )
                raise RuntimeError(message) from e
            image_viewer.SetApplication(str(viewer_path))
            image_viewer.Execute(sitk_image)
887
888
    def _crop_from_slices(
        self,
        slices: TypeSlice | tuple[TypeSlice, ...],
    ) -> Image:
        """Crop the image using Python indexing over its spatial dimensions.

        Each element of ``slices`` indexes one spatial dimension, in order.
        Spatial dimensions without a corresponding element are not cropped.

        Args:
            slices: A single integer or slice, or a tuple with up to three of
                them. An integer selects one voxel along that dimension, which
                is kept with size 1. Negative integers count from the end, as
                in Python sequences. Slices must have a step of 1.

        Returns:
            The output of :class:`~torchio.transforms.Crop` applied to this image.

        Raises:
            IndexError: If more than three indices are given, or an integer
                index is out of bounds.
            NotImplementedError: If an ``Ellipsis`` is used.
            TypeError: If an element is neither an integer nor a slice.
            ValueError: If a slice has a step different from 1.
        """
        from ..transforms import Crop

        num_spatial_dims = 3
        slices_tuple = to_tuple(slices)  # type: ignore[assignment]
        if len(slices_tuple) > num_spatial_dims:
            message = (
                f'Too many indices: the image has {num_spatial_dims} spatial'
                f' dimensions but {len(slices_tuple)} indices were given'
            )
            raise IndexError(message)

        cropping: list[int] = []
        for dim, slice_ in enumerate(slices_tuple):
            shape_dim = self.spatial_shape[dim]
            if isinstance(slice_, slice):
                pass
            elif slice_ is Ellipsis:
                message = 'Ellipsis slicing is not supported yet'
                raise NotImplementedError(message)
            elif isinstance(slice_, int):
                # Normalize negative indices first: slice(-1, 0) would select
                # nothing and produce an invalid cropping
                index = slice_ + shape_dim if slice_ < 0 else slice_
                if not 0 <= index < shape_dim:
                    message = (
                        f'Index {slice_} is out of bounds for spatial dimension'
                        f' {dim} with size {shape_dim}'
                    )
                    raise IndexError(message)
                slice_ = slice(index, index + 1)  # type: ignore[assignment]
            else:
                message = f'Slice type not understood: "{type(slice_)}"'
                raise TypeError(message)
            assert isinstance(slice_, slice)
            start, stop, step = slice_.indices(shape_dim)
            if step != 1:
                message = (
                    'Slicing with steps different from 1 is not supported yet.'
                    ' Use the Crop transform instead'
                )
                raise ValueError(message)
            # Number of voxels removed from the beginning and from the end
            crop_ini = start
            crop_fin = shape_dim - stop
            cropping.extend([crop_ini, crop_fin])

        # Dimensions not indexed (including all of them for an empty tuple)
        # are left untouched
        num_missing = num_spatial_dims - len(slices_tuple)
        cropping.extend([0, 0] * num_missing)
        w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
        cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin  # making mypy happy
        return Crop(cropping_arg)(self)  # type: ignore[return-value]
925
926
927
class ScalarImage(Image):
    """Image whose voxel values represent scalar quantities.

    The image type is always ``torchio.INTENSITY``.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        # Any explicitly requested type other than INTENSITY is rejected
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)

    def hist(self, **kwargs) -> None:
        """Plot a histogram of all the voxel values.

        Args:
            **kwargs: Forwarded to :func:`~torchio.visualization.plot_histogram`.
        """
        from ..visualization import plot_histogram

        plot_histogram(self.data.flatten().numpy(), **kwargs)

    def to_video(
        self,
        output_path: TypePath,
        frame_rate: float | None = 15,
        seconds: float | None = None,
        direction: str = 'I',
        verbosity: str = 'error',
    ) -> None:
        """Create a video showing all image slices along a specified direction.

        The image is first converted to RAS+ with :meth:`to_ras`, and then
        passed to :func:`~torchio.visualization.make_video`.

        Args:
            output_path: Path to the output video file.
            frame_rate: Number of frames per second (FPS).
            seconds: Target duration of the full video.
            direction: Forwarded to ``make_video`` (default ``'I'``).
                Presumably this is the anatomical direction, in the RAS+ frame,
                along which slices are traversed. Confirm in ``make_video``.
            verbosity: Forwarded to ``make_video`` (default ``'error'``).
                Presumably this is the logging level of the video encoder.
                Confirm in ``make_video``.

        .. note:: Only ``frame_rate`` or ``seconds`` may (and must) be
            specified. Because ``frame_rate`` defaults to ``15``, pass
            ``frame_rate=None`` when using ``seconds``.
        """
        from ..visualization import make_video  # avoid circular import

        ras_image = self.to_ras()
        make_video(
            ras_image,  # type: ignore[arg-type]
            output_path,
            frame_rate=frame_rate,
            seconds=seconds,
            direction=direction,
            verbosity=verbosity,
        )
992
993
994
class LabelMap(Image):
    """Image whose pixel values represent segmentation labels.

    A sequence of paths to 3D images can be passed to create a 4D image.
    This is useful to create a
    `tissue probability map (TPM) <https://andysbrainbook.readthedocs.io/en/latest/SPM/SPM_Short_Course/SPM_04_Preprocessing/04_SPM_Segmentation.html#tissue-probability-maps>`_,
    which contains the probability of each voxel belonging to a certain tissue type,
    or to create a label map with overlapping labels.

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> binary_tensor = torch.rand(1, 128, 128, 68) > 0.5
        >>> label_map = tio.LabelMap(tensor=binary_tensor)  # from a tensor
        >>> label_map = tio.LabelMap('t1_seg.nii.gz')  # from a file
        >>> # Create a 4D tissue probability map from different 3D images
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
        >>> tpm = tio.LabelMap(tissues)

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        # The type is forced to LABEL; any other explicit type is an error
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)

    def count_nonzero(self) -> int:
        """Get the number of voxels that are not 0."""
        return int(self.data.count_nonzero().item())

    def count_labels(self) -> dict[int, int]:
        """Get the number of voxels in each label.

        All channels are counted together.

        Returns:
            Dictionary mapping each value present in the data to its number
            of voxels, with keys in ascending order.
        """
        # Count in native code instead of converting every voxel into a
        # Python object (flatten().tolist() + Counter), which is very slow
        # for typical 3D volumes
        labels, counts = torch.unique(
            self.data,
            sorted=True,
            return_counts=True,
        )
        return dict(zip(labels.tolist(), counts.tolist()))
1038