Passed
Push — main ( bf7ac6...cabfda )
by Fernando
01:25
created

torchio.data.image.Image._parse_single_path()   B

Complexity

Conditions 6

Size

Total Lines 23
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 19
nop 3
dl 0
loc 23
rs 8.5166
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import warnings
4
from collections import Counter
5
from collections.abc import Sequence
6
from pathlib import Path
7
from typing import Any
8
from typing import Callable
9
10
import humanize
11
import nibabel as nib
12
import numpy as np
13
import SimpleITK as sitk
14
import torch
15
from deprecated import deprecated
16
from nibabel.affines import apply_affine
17
18
from ..constants import AFFINE
19
from ..constants import DATA
20
from ..constants import INTENSITY
21
from ..constants import LABEL
22
from ..constants import PATH
23
from ..constants import STEM
24
from ..constants import TENSOR
25
from ..constants import TYPE
26
from ..types import TypeData
27
from ..types import TypeDataAffine
28
from ..types import TypeDirection3D
29
from ..types import TypePath
30
from ..types import TypeQuartetInt
31
from ..types import TypeSlice
32
from ..types import TypeTripletFloat
33
from ..types import TypeTripletInt
34
from ..utils import get_stem
35
from ..utils import guess_external_viewer
36
from ..utils import in_torch_loader
37
from ..utils import is_iterable
38
from ..utils import to_tuple
39
from .io import check_uint_to_int
40
from .io import ensure_4d
41
from .io import get_rotation_and_spacing_from_affine
42
from .io import get_sitk_metadata_from_ras_affine
43
from .io import nib_to_sitk
44
from .io import read_affine
45
from .io import read_image
46
from .io import read_shape
47
from .io import sitk_to_nib
48
from .io import write_image
49
50
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
51
TypeBound = tuple[float, float]
52
TypeBounds = tuple[TypeBound, TypeBound, TypeBound]
53
54
deprecation_message = (
55
    'Setting the image data with the property setter is deprecated. Use the'
56
    ' set_data() method instead'
57
)
58
59
60
class Image(dict):
61
    r"""TorchIO image.
62
63
    For information about medical image orientation, check out `NiBabel docs`_,
64
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
65
    `SimpleITK docs`_.
66
67
    Args:
68
        path: Path to a file or sequence of paths to files that can be read by
69
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
70
            DICOM files. If :attr:`tensor` is given, the data in
71
            :attr:`path` will not be read.
72
            If a sequence of paths is given, data
73
            will be concatenated on the channel dimension so spatial
74
            dimensions must match.
75
        type: Type of image, such as :attr:`torchio.INTENSITY` or
76
            :attr:`torchio.LABEL`. This will be used by the transforms to
77
            decide whether to apply an operation, or which interpolation to use
78
            when resampling. For example, `preprocessing`_ and `augmentation`_
79
            intensity transforms will only be applied to images with type
80
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
81
            all types, and nearest neighbor interpolation is always used to
82
            resample images with type :attr:`torchio.LABEL`.
83
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
84
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
85
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
86
            :class:`torch.Tensor` or NumPy array with dimensions
87
            :math:`(C, W, H, D)`.
88
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
89
            coordinates. If ``None``, an identity matrix will be used. See the
90
            `NiBabel docs on coordinates`_ for more information.
91
        check_nans: If ``True``, issues a warning if NaNs are found
92
            in the image. If ``False``, images will not be checked for the
93
            presence of NaNs.
94
        reader: Callable object that takes a path and returns a 4D tensor and a
95
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
96
            is saved in a custom format, such as ``.npy`` (see example below).
97
            If the affine matrix is ``None``, an identity matrix will be used.
98
        **kwargs: Items that will be added to the image dictionary, e.g.
99
            acquisition parameters or image ID.
100
        verify_path: If ``True``, the path will be checked to see if it exists. If
101
            ``False``, the path will not be verified. This is useful when it is
102
            expensive to check the path, e.g., when reading a large dataset from a
103
            mounted drive.
104
105
    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
106
    when needed.
107
108
    Example:
109
        >>> import torchio as tio
110
        >>> import numpy as np
111
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
112
        >>> image  # not loaded yet
113
        ScalarImage(path: t1.nii.gz; type: intensity)
114
        >>> times_two = 2 * image.data  # data is loaded and cached here
115
        >>> image
116
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
117
        >>> image.save('doubled_image.nii.gz')
118
        >>> def numpy_reader(path):
119
        ...     data = np.load(path).astype(np.float32)
120
        ...     affine = np.eye(4)
121
        ...     return data, affine
122
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)
123
124
    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
125
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
126
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
127
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
128
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
129
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
130
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
131
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
132
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
133
    """
134
135
    def __init__(
        self,
        path: TypePath | Sequence[TypePath] | None = None,
        type: str | None = None,  # noqa: A002
        tensor: TypeData | None = None,
        affine: TypeData | None = None,
        check_nans: bool = False,  # removed by ITK by default
        reader: Callable[[TypePath], TypeDataAffine] = read_image,
        verify_path: bool = True,
        **kwargs: dict[str, Any],
    ):
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                FutureWarning,
                stacklevel=2,
            )
            type = INTENSITY  # noqa: A001

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Validate tensor and affine eagerly; data from a path is only read
        # lazily, when first accessed
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True

        # Reject keyword arguments that would clash with reserved keys
        for reserved in PROTECTED_KEYS:
            if reserved in kwargs:
                message = f'Key "{reserved}" is reserved. Use a different one'
                raise ValueError(message)
        if 'channels_last' in kwargs:
            message = (
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/TorchIO-project/torchio/pull/685 and will be'
                ' removed in the future'
            )
            warnings.warn(message, FutureWarning, stacklevel=2)

        super().__init__(**kwargs)
        self._check_data_loader()
        self.path = self._parse_path(path, verify=verify_path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
188
189
    def __repr__(self):
        """Return a one-line summary with shape, spacing and orientation."""
        parts = [
            f'shape: {self.shape}',
            f'spacing: {self.get_spacing_string()}',
            f'orientation: {"".join(self.orientation)}+',
        ]
        if self._loaded:
            # Data is cached in memory: report dtype and size
            parts.append(f'dtype: {self.data.type()}')
            natural = humanize.naturalsize(self.memory, binary=True)
            parts.append(f'memory: {natural}')
        else:
            # Not loaded yet: show where the data would come from
            parts.append(f'path: "{self.path}"')
        joined = '; '.join(parts)
        return f'{self.__class__.__name__}({joined})'
208
209
    def __getitem__(self, item):
        """Dictionary access with two extras: slicing/indexing crops the
        image, and accessing the data or affine keys triggers lazy loading.
        """
        if isinstance(item, (slice, int, tuple)):
            return self._crop_from_slices(item)
        if item in (DATA, AFFINE) and item not in self:
            self.load()  # populate the missing key from disk
        return super().__getitem__(item)
217
218
    def __array__(self):
        """Support ``np.asarray(image)`` by exposing the data as an array."""
        return self.data.numpy()
220
221
    def __copy__(self):
        """Create a copy of the same subclass, preserving type, path, data
        (if loaded) and user attributes.
        """
        kwargs = {TYPE: self.type, PATH: self.path}
        if self._loaded:
            # Pass the cached tensor and affine so the copy is also loaded
            kwargs[TENSOR] = self.data
            kwargs[AFFINE] = self.affine
        # Carry over user attributes, skipping reserved keys
        # (values are shared, not copied — see original "deepcopy?" question)
        extras = {k: v for k, v in self.items() if k not in PROTECTED_KEYS}
        kwargs.update(extras)
        cls = type(self)
        return cls(
            check_nans=self.check_nans,
            reader=self.reader,
            **kwargs,
        )
240
241
    @staticmethod
    def _check_data_loader() -> None:
        """Warn when an image is created inside a plain PyTorch DataLoader.

        Since PyTorch 2.3, collating TorchIO subjects with a regular
        DataLoader may produce unexpected results (see the warning text and
        the linked issue); a ``torchio.SubjectsLoader`` should be used.
        """
        # Fix: compare the version numerically. The previous lexicographic
        # string comparison (torch.__version__ >= '2.3') is wrong for
        # versions such as '2.10.0'.
        major_minor = torch.__version__.split('.')[:2]
        try:
            is_recent = tuple(int(part) for part in major_minor) >= (2, 3)
        except ValueError:
            # Unparsable version string: fall back to the original check
            is_recent = torch.__version__ >= '2.3'
        if is_recent and in_torch_loader():
            message = (
                'Using TorchIO images without a torchio.SubjectsLoader in PyTorch >='
                ' 2.3 might have unexpected consequences, e.g., the collated batches'
                ' will be instances of torchio.Subject with 5D images. Replace'
                ' your PyTorch DataLoader with a torchio.SubjectsLoader so that'
                ' the collated batch becomes a dictionary, as expected. See'
                ' https://github.com/TorchIO-project/torchio/issues/1179 for more'
                ' context about this issue.'
            )
            warnings.warn(message, stacklevel=1)
254
255
    @property
    def data(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.tensor`)."""
        # Triggers lazy loading via __getitem__ if not loaded yet
        return self[DATA]
259
260
    @data.setter  # type: ignore[misc]
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        # Deprecated setter kept for backward compatibility; use set_data()
        self.set_data(tensor)
264
265
    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        # Validation (type, 4D shape, NaN check) happens in _parse_tensor
        self[DATA] = self._parse_tensor(tensor, none_ok=False)
272
273
    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.data`)."""
        return self.data
277
278
    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        # The affine can be read cheaply from disk metadata, except when:
        # - the image is already loaded (use the cached matrix)
        # - the path is a directory (probably DICOM)
        # - the path is a list of paths (used to create a 4D image)
        # - a custom reader is used (SimpleITK probably cannot read the
        #   metadata, so everything must be loaded into memory)
        uses_custom_reader = self.reader is not read_image
        must_use_cached = (
            self._loaded
            or self._is_dir()
            or self._is_multipath()
            or uses_custom_reader
        )
        if must_use_cached:
            return self[AFFINE]
        assert self.path is not None
        assert isinstance(self.path, (str, Path))
        return read_affine(self.path)
293
294
    @affine.setter
    def affine(self, matrix):
        # Validate and normalize to a float64 (4, 4) NumPy array
        self[AFFINE] = self._parse_affine(matrix)
297
298
    @property
    def type(self) -> str:  # noqa: A003
        """Image type, e.g. intensity or label."""
        return self[TYPE]
301
302
    @property
    def shape(self) -> TypeQuartetInt:
        """Tensor shape as :math:`(C, W, H, D)`.

        Read from disk metadata when possible; otherwise (already loaded,
        custom reader, multiple paths, or a DICOM directory) from the data.
        """
        custom_reader = self.reader is not read_image
        multipath = self._is_multipath()
        # Fix: is_dir was previously only assigned when self.path was a Path,
        # raising UnboundLocalError below for e.g. path=None or a sequence
        is_dir = isinstance(self.path, Path) and self.path.is_dir()
        shape: TypeQuartetInt
        if self._loaded or custom_reader or multipath or is_dir:
            channels, si, sj, sk = self.data.shape
            shape = channels, si, sj, sk
        else:
            assert isinstance(self.path, (str, Path))
            shape = read_shape(self.path)
        return shape
317
318
    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        # Drop the channel dimension from (C, W, H, D)
        return self.shape[1:]
322
323
    def check_is_2d(self) -> None:
        """Raise :class:`RuntimeError` if the image is not 2D."""
        if self.is_2d():
            return
        message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
        raise RuntimeError(message)
327
328
    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()  # raises for non-2D images
        return self.spatial_shape[1]
333
334
    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()  # raises for non-2D images
        return self.spatial_shape[0]
339
340
    @property
    def orientation(self) -> tuple[str, str, str]:
        """Orientation codes, as returned by :func:`nibabel.aff2axcodes`."""
        return nib.aff2axcodes(self.affine)
344
345
    @property
    def direction(self) -> TypeDirection3D:
        """Image direction, derived from the affine (RAS convention)."""
        metadata = get_sitk_metadata_from_ras_affine(self.affine, lps=False)
        # metadata is (origin, spacing, direction); only direction is needed
        return metadata[2]  # type: ignore[return-value]
352
353
    @property
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm."""
        rotation_and_spacing = get_rotation_and_spacing_from_affine(self.affine)
        sx, sy, sz = rotation_and_spacing[1]
        return sx, sy, sz
359
360
    @property
    def origin(self) -> tuple[float, float, float]:
        """Center of first voxel in array, in mm."""
        # The translation column of the affine is the world position of
        # voxel (0, 0, 0)
        ox, oy, oz = self.affine[:3, 3]
        return ox, oy, oz
365
366
    @property
    def itemsize(self):
        """Element size of the data type, in bytes."""
        return self.data.element_size()
370
371
    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        num_elements = np.prod(self.shape)
        return num_elements * self.itemsize
375
376
    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest indices."""
        first_index = 0, 0, 0
        last_index = np.array(self.spatial_shape) - 1
        points = [
            apply_affine(self.affine, first_index),
            apply_affine(self.affine, last_index),
        ]
        return np.array(points)
384
385
    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        # The channel dimension is the first one
        return len(self.data)
389
390
    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        initial = axis[0].upper()

        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous.
        # Generally, TorchIO tensors are (C, W, H, D)
        if initial in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(initial)
        except ValueError:
            # Try the opposite label, e.g. 'L' when the orientation has 'R'
            index = self.orientation.index(self.flip_axis(initial))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return index - 3
426
427
    @staticmethod
428
    def flip_axis(axis: str) -> str:
429
        """Return the opposite axis label. For example, ``'L'`` -> ``'R'``.
430
431
        Args:
432
            axis: Axis label, such as ``'L'`` or ``'left'``.
433
        """
434
        labels = 'LRPAISTBDV'
435
        first = labels[::2]
436
        last = labels[1::2]
437
        flip_dict = dict(zip(first + last, last + first))
438
        axis = axis[0].upper()
439
        flipped_axis = flip_dict.get(axis)
440
        if flipped_axis is None:
441
            values = ', '.join(labels)
442
            message = f'Axis not understood. Please use one of: {values}'
443
            raise ValueError(message)
444
        return flipped_axis
445
446
    def get_spacing_string(self) -> str:
        """Return the spacing formatted like ``'(1.00, 1.00, 1.00)'``."""
        formatted = ', '.join(f'{value:.2f}' for value in self.spacing)
        return f'({formatted})'
450
451
    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        # Voxel extents reach half a voxel beyond the first and last centers
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        corners = np.array(
            [
                apply_affine(self.affine, first_index),
                apply_affine(self.affine, last_index),
            ]
        )
        bounds_x, bounds_y, bounds_z = corners.T.tolist()  # type: ignore[misc]
        return bounds_x, bounds_y, bounds_z  # type: ignore[return-value]
460
461
    @staticmethod
462
    def _parse_single_path(
463
        path: TypePath,
464
        *,
465
        verify: bool = True,
466
    ) -> Path:
467
        try:
468
            path = Path(path).expanduser()
469
        except TypeError as err:
470
            message = (
471
                f'Expected type str or Path but found {path} with type'
472
                f' {type(path)} instead'
473
            )
474
            raise TypeError(message) from err
475
        except RuntimeError as err:
476
            message = f'Conversion to path not possible for variable: {path}'
477
            raise RuntimeError(message) from err
478
        if not verify:
479
            return path
480
481
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
482
            raise FileNotFoundError(f'File not found: "{path}"')
483
        return path
484
485
    def _parse_path(
        self,
        path: TypePath | Sequence[TypePath] | None,
        *,
        verify: bool = True,
    ) -> Path | list[Path] | None:
        """Normalize the path argument: ``None``, one path, or a sequence."""
        if path is None:
            return None
        if isinstance(path, dict):
            # https://github.com/TorchIO-project/torchio/pull/838
            raise TypeError('The path argument cannot be a dictionary')
        if self._is_paths_sequence(path):
            return [self._parse_single_path(p, verify=verify) for p in path]  # type: ignore[union-attr]
        return self._parse_single_path(path, verify=verify)  # type: ignore[arg-type]
500
501
    def _parse_tensor(
        self,
        tensor: TypeData | None,
        none_ok: bool = True,
    ) -> torch.Tensor | None:
        """Validate the input and return it as a 4D :class:`torch.Tensor`.

        Raises:
            RuntimeError: If the input is ``None`` and *none_ok* is ``False``.
            TypeError: If the input is neither a tensor nor a NumPy array.
            ValueError: If the input is not 4D.
        """
        if tensor is None:
            if not none_ok:
                raise RuntimeError('Input tensor cannot be None')
            return None
        if isinstance(tensor, np.ndarray):
            # Convert dtypes PyTorch cannot represent before wrapping
            tensor = torch.as_tensor(check_uint_to_int(tensor))
        elif not isinstance(tensor, torch.Tensor):
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {tensor.ndim}D')
        if tensor.dtype == torch.bool:
            # Booleans are stored as uint8
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning, stacklevel=2)
        return tensor
528
529
    @staticmethod
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        # Add singleton dimensions as needed so the tensor is 4D
        return ensure_4d(tensor)
532
533
    @staticmethod
534
    def _parse_affine(affine: TypeData | None) -> np.ndarray:
535
        if affine is None:
536
            return np.eye(4)
537
        if isinstance(affine, torch.Tensor):
538
            affine = affine.numpy()
539
        if not isinstance(affine, np.ndarray):
540
            bad_type = type(affine)
541
            raise TypeError(f'Affine must be a NumPy array, not {bad_type}')
542
        if affine.shape != (4, 4):
543
            bad_shape = affine.shape
544
            raise ValueError(f'Affine shape must be (4, 4), not {bad_shape}')
545
        return affine.astype(np.float64)
546
547
    @staticmethod
    def _is_paths_sequence(path: TypePath | Sequence[TypePath] | None) -> bool:
        # A string is iterable but represents a single path, so exclude it
        return not isinstance(path, str) and is_iterable(path)
551
552
    def _is_multipath(self) -> bool:
        """Return ``True`` if this image was built from several paths."""
        return self._is_paths_sequence(self.path)
554
555
    def _is_dir(self) -> bool:
        """Return ``True`` if the path is a single existing directory."""
        if self._is_multipath() or self.path is None:
            return False
        assert isinstance(self.path, Path)
        return self.path.is_dir()
564
565
    def load(self) -> None:
        r"""Load the image from disk and cache the data and affine.

        The data is stored as a 4D tensor of size :math:`(C, W, H, D)` and
        the affine as a :math:`4 \times 4` matrix. When the image was built
        from several paths, the volumes are concatenated along the channel
        dimension. Does nothing if the image is already loaded.

        Raises:
            RuntimeError: If files have mismatching spatial shapes.
        """
        if self._loaded:
            return

        paths: list[Path]
        if self._is_multipath():
            paths = self.path  # type: ignore[assignment]
        else:
            paths = [self.path]  # type: ignore[list-item]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning, stacklevel=2)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                # Fix: the RuntimeError was previously instantiated but never
                # raised, silently letting mismatched shapes reach torch.cat.
                # Also added the missing space before "and" in the message.
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True
605
606
    def unload(self) -> None:
        """Unload the image from memory.

        Raises:
            RuntimeError: If the image has not been loaded yet, or if no path
                is available to load it again later.
        """
        if not self._loaded:
            message = 'Image cannot be unloaded as it has not been loaded yet'
            raise RuntimeError(message)
        if self.path is None:
            message = (
                'Cannot unload image as no path is available'
                ' from where the image could be loaded again'
            )
            raise RuntimeError(message)
        # Drop the cached tensor and affine; they will be read from disk
        # again on the next access
        self[DATA] = None
        self[AFFINE] = None
        self._loaded = False
625
626
    def read_and_check(self, path: TypePath) -> TypeDataAffine:
        """Read one file with the configured reader, validating the result."""
        tensor, affine = self.reader(path)
        # Make sure the data type is compatible with PyTorch
        if self.reader is not read_image and isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
        tensor = self._parse_tensor_shape(tensor)  # type: ignore[assignment]
        tensor = self._parse_tensor(tensor)  # type: ignore[assignment]
        affine = self._parse_affine(affine)
        if self.check_nans and torch.isnan(tensor).any():
            message = f'NaNs found in file "{path}"'
            warnings.warn(message, RuntimeWarning, stacklevel=2)
        return tensor, affine
641
642
    def save(self, path: TypePath, squeeze: bool | None = None) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        write_image(self.data, self.affine, path, squeeze=squeeze)
657
658
    def is_2d(self) -> bool:
        # An image is considered 2D when its last spatial dimension is 1
        return self.shape[-1] == 1
660
661
    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)  # delegates to __array__
664
665
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)
668
669
    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 93.8 KiB; dtype: torch.IntTensor)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 588.0 KiB; dtype: torch.FloatTensor)
        """
        # Convert to the (tensor, affine) representation used internally
        tensor, affine = sitk_to_nib(sitk_image)
        return cls(tensor=tensor, affine=affine)
685
686
    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        .. note:: To use this method, Pillow needs to be installed:
            ``pip install Pillow``.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = 'Please install Pillow to use Image.as_pil(): pip install Pillow'
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        num_channels = len(tensor)
        if num_channels not in (1, 3, 4):
            raise NotImplementedError(
                'Only 1, 3 or 4 channels are supported for conversion to Pillow image'
            )
        if num_channels == 1:
            # Replicate the single channel to build an RGB image
            tensor = torch.cat(3 * [tensor])
        axes = (3, 2, 1, 0) if transpose else (3, 1, 2, 0)
        tensor = tensor.permute(*axes)
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))
714
715
    def to_gif(
        self,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        rescale: bool = True,
        optimize: bool = True,
        reverse: bool = False,
    ) -> None:
        """Save an animated GIF of the image.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """
        from ..visualization import make_gif  # avoid circular import

        make_gif(
            self.data,
            axis,
            duration,
            output_path,
            loop=loop,
            rescale=rescale,
            optimize=optimize,
            reverse=reverse,
        )
752
753
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        shape_array = np.array(self.spatial_shape)
        center_index = (shape_array - 1) / 2
        r, a, s = apply_affine(self.affine, center_index)
        # RAS+ and LPS+ differ by the sign of the first two coordinates
        return (-r, -a, s) if lps else (r, a, s)
768
769
    def set_check_nans(self, check_nans: bool) -> None:
        """Set the ``check_nans`` flag on this image."""
        self.check_nans = check_nans
771
772
    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            # 2D images are displayed directly as a PIL image
            self.as_pil().show()
            return
        from ..visualization import plot_volume  # deferred to avoid circular import

        plot_volume(self, **kwargs)
780
781
    def show(self, viewer_path: TypePath | None = None) -> None:
        """Open the image using external software.

        Args:
            viewer_path: Path to the application used to view the image. If
                ``None``, the value of the environment variable
                ``SITK_SHOW_COMMAND`` will be used. If this variable is also
                not set, TorchIO will try to guess the location of
                `ITK-SNAP <http://www.itksnap.org/pmwiki/pmwiki.php>`_ and
                `3D Slicer <https://www.slicer.org/>`_.

        Raises:
            RuntimeError: If the viewer is not found.
        """
        sitk_image = self.as_sitk()
        viewer = sitk.ImageViewer()
        # This is so that 3D Slicer creates segmentation nodes from label maps
        if self.__class__.__name__ == 'LabelMap':
            viewer.SetFileExtension('.seg.nrrd')
        if viewer_path is not None:
            viewer.SetApplication(str(viewer_path))
        try:
            viewer.Execute(sitk_image)
        except RuntimeError as e:
            # First launch failed; retry once with a guessed viewer location
            fallback_path = guess_external_viewer()
            if fallback_path is None:
                message = (
                    'No external viewer has been found. Please set the'
                    ' environment variable SITK_SHOW_COMMAND to a viewer of'
                    ' your choice'
                )
                raise RuntimeError(message) from e
            viewer.SetApplication(str(fallback_path))
            viewer.Execute(sitk_image)
815
816
    def _crop_from_slices(
        self,
        slices: TypeSlice | tuple[TypeSlice, ...],
    ) -> Image:
        """Crop the image according to the given spatial slices.

        Args:
            slices: A slice, an integer index, or a tuple of those, one per
                spatial dimension. Trailing dimensions not covered by
                ``slices`` are left uncropped.

        Returns:
            The cropped image, as produced by the
            :class:`~torchio.transforms.Crop` transform.

        Raises:
            NotImplementedError: If ``Ellipsis`` is used for slicing.
            TypeError: If an element is neither a slice nor an integer.
            ValueError: If a slice uses a step different from 1.
        """
        from ..transforms import Crop  # avoid circular import

        slices_tuple = to_tuple(slices)  # type: ignore[assignment]
        cropping: list[int] = []
        # Bug fix: initialize ``dim`` so an empty ``slices_tuple`` pads all
        # three dimensions below instead of raising NameError (``dim`` was
        # previously only bound inside the loop).
        dim = -1
        for dim, slice_ in enumerate(slices_tuple):
            if isinstance(slice_, slice):
                pass
            elif slice_ is Ellipsis:
                message = 'Ellipsis slicing is not supported yet'
                raise NotImplementedError(message)
            elif isinstance(slice_, int):
                # A bare index i selects the single-voxel slab [i, i + 1)
                slice_ = slice(slice_, slice_ + 1)  # type: ignore[assignment]
            else:
                message = f'Slice type not understood: "{type(slice_)}"'
                raise TypeError(message)
            shape_dim = self.spatial_shape[dim]
            assert isinstance(slice_, slice)
            # indices() normalizes negative/None bounds against the axis size
            start, stop, step = slice_.indices(shape_dim)
            if step != 1:
                message = (
                    'Slicing with steps different from 1 is not supported yet.'
                    ' Use the Crop transform instead'
                )
                raise ValueError(message)
            crop_ini = start
            # Amount to remove at the far end of this dimension
            crop_fin = shape_dim - stop
            cropping.extend([crop_ini, crop_fin])
        # Pad the remaining spatial dimensions with zero cropping
        while dim < 2:
            cropping.extend([0, 0])
            dim += 1
        w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
        cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin  # making mypy happy
        return Crop(cropping_arg)(self)  # type: ignore[return-value]
853
854
855
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        # The image type is fixed for this subclass; reject conflicting values
        explicit_type = kwargs.get('type', INTENSITY)
        if explicit_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)

    def hist(self, **kwargs) -> None:
        """Plot histogram."""
        from ..visualization import plot_histogram  # avoid circular import

        values = self.data.flatten().numpy()
        plot_histogram(values, **kwargs)
890
891
892
class LabelMap(Image):
    """Image whose pixel values represent segmentation labels.

    A sequence of paths to 3D images can be passed to create a 4D image.
    This is useful to create a
    `tissue probability map (TPM) <https://andysbrainbook.readthedocs.io/en/latest/SPM/SPM_Short_Course/SPM_04_Preprocessing/04_SPM_Segmentation.html#tissue-probability-maps>`,
    which contains the probability of each voxel belonging to a certain tissue type,
    or to create a label map with overlapping labels.

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> binary_tensor = torch.rand(1, 128, 128, 68) > 0.5
        >>> label_map = tio.LabelMap(tensor=binary_tensor)  # from a tensor
        >>> label_map = tio.LabelMap('t1_seg.nii.gz')  # from a file
        >>> # Create a 4D tissue probability map from different 3D images
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
        >>> tpm = tio.LabelMap(tissues)

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        # The image type is fixed for this subclass; reject conflicting values
        explicit_type = kwargs.get('type', LABEL)
        if explicit_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)

    def count_nonzero(self) -> int:
        """Get the number of voxels that are not 0."""
        nonzero_voxels = self.data.count_nonzero()
        return int(nonzero_voxels.item())

    def count_labels(self) -> dict[int, int]:
        """Get the number of voxels in each label."""
        counter = Counter(self.data.flatten().tolist())
        # Return counts keyed by label value, in ascending label order
        return {label: counter[label] for label in sorted(counter)}
936