Passed
Pull Request — main (#1327)
by Fernando
01:27
created

torchio.data.image.Image.update_attributes()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import warnings
4
from collections import Counter
5
from collections.abc import Sequence
6
from pathlib import Path
7
from typing import Any
8
from typing import Callable
9
10
import humanize
11
import nibabel as nib
12
import numpy as np
13
import SimpleITK as sitk
14
import torch
15
from deprecated import deprecated
16
from nibabel.affines import apply_affine
17
18
from ..constants import AFFINE
19
from ..constants import DATA
20
from ..constants import INTENSITY
21
from ..constants import LABEL
22
from ..constants import PATH
23
from ..constants import STEM
24
from ..constants import TENSOR
25
from ..constants import TYPE
26
from ..types import TypeData
27
from ..types import TypeDataAffine
28
from ..types import TypeDirection3D
29
from ..types import TypePath
30
from ..types import TypeQuartetInt
31
from ..types import TypeSlice
32
from ..types import TypeTripletFloat
33
from ..types import TypeTripletInt
34
from ..utils import get_stem
35
from ..utils import guess_external_viewer
36
from ..utils import in_torch_loader
37
from ..utils import is_iterable
38
from ..utils import to_tuple
39
from .io import check_uint_to_int
40
from .io import ensure_4d
41
from .io import get_rotation_and_spacing_from_affine
42
from .io import get_sitk_metadata_from_ras_affine
43
from .io import nib_to_sitk
44
from .io import read_affine
45
from .io import read_image
46
from .io import read_shape
47
from .io import sitk_to_nib
48
from .io import write_image
49
50
# Dictionary keys managed by Image itself; passing them as extra keyword
# arguments to the constructor raises an error
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
# Pair of world coordinates along a single axis (see Image.get_bounds)
TypeBound = tuple[float, float]
# One pair for each of the three spatial axes
TypeBounds = tuple[TypeBound, TypeBound, TypeBound]

# Reason shown by the deprecated ``Image.data`` property setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
            coordinates. If ``None``, an identity matrix will be used. See the
            `NiBabel docs on coordinates`_ for more information.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        reader: Callable object that takes a path and returns a 4D tensor and a
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
            is saved in a custom format, such as ``.npy`` (see example below).
            If the affine matrix is ``None``, an identity matrix will be used.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> import numpy as np
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; dtype: torch.FloatTensor; memory: 44.0 MiB)
        >>> image.save('doubled_image.nii.gz')
        >>> def numpy_reader(path):
        ...     data = np.load(path).astype(np.float32)
        ...     affine = np.eye(4)
        ...     return data, affine
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """

    def __init__(
        self,
        path: TypePath | Sequence[TypePath] | None = None,
        type: str | None = None,  # noqa: A002
        tensor: TypeData | None = None,
        affine: TypeData | None = None,
        check_nans: bool = False,  # removed by ITK by default
        reader: Callable[[TypePath], TypeDataAffine] = read_image,
        **kwargs: Any,
    ):
        """Create an image from a path or from in-memory data.

        The arguments are described in the class docstring.

        Raises:
            ValueError: If neither ``path`` nor ``tensor`` is given, or if
                ``kwargs`` contains one of the reserved keys.
        """
        self.check_nans = check_nans
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                FutureWarning,
                stacklevel=2,
            )
            type = INTENSITY  # noqa: A001

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        if tensor is not None:
            # set_data() validates the tensor and the affine setter validates
            # the matrix. Parsing them beforehand as well would convert the
            # data twice and emit the NaN warning twice when check_nans is set.
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True
        else:
            # The affine is not stored when the data come from disk, but it is
            # still validated so that an invalid matrix raises immediately
            self._parse_affine(affine)

        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        if 'channels_last' in kwargs:
            message = (
                'The "channels_last" keyword argument is deprecated after'
                ' https://github.com/TorchIO-project/torchio/pull/685 and will be'
                ' removed in the future'
            )
            warnings.warn(message, FutureWarning, stacklevel=2)

        super().__init__(**kwargs)
        self._check_data_loader()
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

        # Expose dictionary items as attributes, e.g. image.modality
        self.update_attributes()

    def __repr__(self):
        """Summarize the geometry plus either memory usage or the source path.

        Loaded images report their dtype and memory footprint; lazy images
        report the path the data will be read from.
        """
        fields = [
            f'shape: {self.shape}',
            f'spacing: {self.get_spacing_string()}',
            f'orientation: {"".join(self.orientation)}+',
        ]
        if not self._loaded:
            fields.append(f'path: "{self.path}"')
        else:
            dtype_name = self.data.type()
            human_size = humanize.naturalsize(self.memory, binary=True)
            fields += [f'dtype: {dtype_name}', f'memory: {human_size}']

        summary = '; '.join(fields)
        return f'{self.__class__.__name__}({summary})'

    def __getitem__(self, item):
        """Crop the image or look up a dictionary item.

        Args:
            item: A slice, an integer or a tuple of them to crop the image
                (delegated to ``_crop_from_slices``), or a dictionary key.

        The data and the affine are loaded lazily: if their keys are missing,
        or their values were set to ``None`` by :meth:`unload`,
        :meth:`load` is called before returning them.
        """
        if isinstance(item, (slice, int, tuple)):
            return self._crop_from_slices(item)

        # unload() keeps the DATA and AFFINE keys but sets them to None, so a
        # plain membership test would hand back None instead of reloading
        if item in (DATA, AFFINE) and super().get(item) is None:
            self.load()
        return super().__getitem__(item)

    def __array__(self):
216
        return self.data.numpy()
217
218
    def __copy__(self):
        """Return a new image of the same class with the same content.

        If the image is loaded, its tensor and affine are handed to the new
        instance; otherwise only the path is, so the copy is lazy too. The
        remaining (non-protected) items are reused as they are, not copied.
        """
        init_kwargs = {TYPE: self.type, PATH: self.path}
        if self._loaded:
            init_kwargs.update({TENSOR: self.data, AFFINE: self.affine})
        # should I copy? deepcopy?
        init_kwargs.update(
            (key, value)
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        )
        image_class = type(self)
        return image_class(
            check_nans=self.check_nans,
            reader=self.reader,
            **init_kwargs,
        )

    def update_attributes(self) -> None:
239
        # This allows to get images using attribute notation, e.g. image.modality
240
        self.__dict__.update(self)
241
242
    @staticmethod
    def _check_data_loader() -> None:
        """Warn about collation issues inside PyTorch data loaders.

        A warning is issued with PyTorch >= 2.3 when ``in_torch_loader()``
        reports that the image is being created inside a data loader.
        """
        if not torch.__version__ >= '2.3':
            return
        if not in_torch_loader():
            return
        message = (
            'Using TorchIO images without a torchio.SubjectsLoader in PyTorch >='
            ' 2.3 might have unexpected consequences, e.g., the collated batches'
            ' will be instances of torchio.Subject with 5D images. Replace'
            ' your PyTorch DataLoader with a torchio.SubjectsLoader so that'
            ' the collated batch becomes a dictionary, as expected. See'
            ' https://github.com/TorchIO-project/torchio/issues/1179 for more'
            ' context about this issue.'
        )
        warnings.warn(message, stacklevel=1)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data (same as :class:`Image.tensor`).

        Accessing it reads the image from disk first if it is not loaded.
        """
        tensor = self[DATA]
        return tensor

    @data.setter  # type: ignore[misc]
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        # Deprecated assignment path kept for backward compatibility; all
        # validation happens in set_data()
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Validate a 4D tensor and store it under the :attr:`data` key.

        Args:
            tensor: 4D PyTorch tensor or NumPy array with dimensions
                :math:`(C, W, H, D)`. ``None`` is not accepted.
        """
        parsed = self._parse_tensor(tensor, none_ok=False)
        self[DATA] = parsed

    @property
275
    def tensor(self) -> torch.Tensor:
276
        """Tensor data (same as :class:`Image.data`)."""
277
        return self.data
278
279
    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates.

        When possible the matrix is obtained with ``read_affine`` directly
        from the path, without loading the voxel data into this image. The
        stored matrix is used instead (loading the image if needed) when the
        image is already loaded, when the path is a directory (probably
        DICOM) or a list of paths, or when a custom reader is used, since
        SimpleITK probably cannot read the metadata of custom formats.
        """
        must_use_stored = (
            self._loaded
            or self._is_dir()
            or self._is_multipath()
            or self.reader is not read_image
        )
        if must_use_stored:
            return self[AFFINE]
        assert self.path is not None
        assert isinstance(self.path, (str, Path))
        return read_affine(self.path)

    @affine.setter
    def affine(self, matrix):
        # Validate the shape and convert to float64 before storing
        parsed = self._parse_affine(matrix)
        self[AFFINE] = parsed

    @property
    def type(self) -> str:  # noqa: A003
        """Image type stored in the dictionary, e.g. intensity or label."""
        image_type = self[TYPE]
        return image_type

    @property
    def shape(self) -> TypeQuartetInt:
        """Tensor shape as :math:`(C, W, H, D)`.

        The shape is obtained with ``read_shape`` without loading the data
        when possible. The data are used (and loaded if needed) when the
        image is already loaded, uses a custom reader, or comes from several
        paths or from a directory.
        """
        custom_reader = self.reader is not read_image
        multipath = self._is_multipath()
        # Evaluated for every kind of path so that the name is always bound;
        # string paths used to leave it undefined (UnboundLocalError)
        is_dir = (
            not multipath
            and self.path is not None
            and Path(self.path).is_dir()
        )
        shape: TypeQuartetInt
        if self._loaded or custom_reader or multipath or is_dir:
            channels, si, sj, sk = self.data.shape
            shape = channels, si, sj, sk
        else:
            assert isinstance(self.path, (str, Path))
            shape = read_shape(self.path)
        return shape

    @property
320
    def spatial_shape(self) -> TypeTripletInt:
321
        """Tensor spatial shape as :math:`(W, H, D)`."""
322
        return self.shape[1:]
323
324
    def check_is_2d(self) -> None:
325
        if not self.is_2d():
326
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
327
            raise RuntimeError(message)
328
329
    @property
330
    def height(self) -> int:
331
        """Image height, if 2D."""
332
        self.check_is_2d()
333
        return self.spatial_shape[1]
334
335
    @property
336
    def width(self) -> int:
337
        """Image width, if 2D."""
338
        self.check_is_2d()
339
        return self.spatial_shape[0]
340
341
    @property
    def orientation(self) -> tuple[str, str, str]:
        """Orientation codes computed from the affine by NiBabel."""
        axis_codes = nib.aff2axcodes(self.affine)
        return axis_codes

    @property
    def direction(self) -> TypeDirection3D:
        """SimpleITK-style direction derived from the RAS affine."""
        affine = self.affine
        _, _, axes_direction = get_sitk_metadata_from_ras_affine(affine, lps=False)
        return axes_direction  # type: ignore[return-value]

    @property
    def spacing(self) -> tuple[float, float, float]:
        """Voxel spacing in mm, derived from the affine."""
        _rotation, spacing_xyz = get_rotation_and_spacing_from_affine(self.affine)
        x_mm, y_mm, z_mm = spacing_xyz
        return (x_mm, y_mm, z_mm)

    @property
362
    def origin(self) -> tuple[float, float, float]:
363
        """Center of first voxel in array, in mm."""
364
        ox, oy, oz = self.affine[:3, 3]
365
        return ox, oy, oz
366
367
    @property
368
    def itemsize(self):
369
        """Element size of the data type."""
370
        return self.data.element_size()
371
372
    @property
373
    def memory(self) -> float:
374
        """Number of Bytes that the tensor takes in the RAM."""
375
        return np.prod(self.shape) * self.itemsize
376
377
    @property
    def bounds(self) -> np.ndarray:
        """World positions of the centers of the first and last voxels.

        Returns:
            Array whose first row is the point for index ``(0, 0, 0)`` and
            whose second row is the point for the last index along each
            spatial axis.
        """
        last_index = np.array(self.spatial_shape) - 1
        corners = [
            apply_affine(self.affine, index)
            for index in ((0, 0, 0), last_index)
        ]
        return np.array(corners)

    @property
387
    def num_channels(self) -> int:
388
        """Get the number of channels in the associated 4D tensor."""
389
        return len(self.data)
390
391
    def axis_name_to_index(self, axis: str) -> int:
392
        """Convert an axis name to an axis index.
393
394
        Args:
395
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
396
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
397
                versions and first letters are also valid, as only the first
398
                letter will be used.
399
400
        .. note:: If you are working with animals, you should probably use
401
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
402
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
403
            respectively.
404
405
        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
406
            ``'Left'`` and ``'Right'``.
407
        """
408
        # Top and bottom are used for the vertical 2D axis as the use of
409
        # Height vs Horizontal might be ambiguous
410
411
        if not isinstance(axis, str):
412
            raise ValueError('Axis must be a string')
413
        axis = axis[0].upper()
414
415
        # Generally, TorchIO tensors are (C, W, H, D)
416
        if axis in 'TB':  # Top, Bottom
417
            return -2
418
        else:
419
            try:
420
                index = self.orientation.index(axis)
421
            except ValueError:
422
                index = self.orientation.index(self.flip_axis(axis))
423
            # Return negative indices so that it does not matter whether we
424
            # refer to spatial dimensions or not
425
            index = -3 + index
426
            return index
427
428
    @staticmethod
429
    def flip_axis(axis: str) -> str:
430
        """Return the opposite axis label. For example, ``'L'`` -> ``'R'``.
431
432
        Args:
433
            axis: Axis label, such as ``'L'`` or ``'left'``.
434
        """
435
        labels = 'LRPAISTBDV'
436
        first = labels[::2]
437
        last = labels[1::2]
438
        flip_dict = dict(zip(first + last, last + first))
439
        axis = axis[0].upper()
440
        flipped_axis = flip_dict.get(axis)
441
        if flipped_axis is None:
442
            values = ', '.join(labels)
443
            message = f'Axis not understood. Please use one of: {values}'
444
            raise ValueError(message)
445
        return flipped_axis
446
447
    def get_spacing_string(self) -> str:
448
        strings = [f'{n:.2f}' for n in self.spacing]
449
        string = f'({", ".join(strings)})'
450
        return string
451
452
    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image.

        Voxel corners are used (indices shifted by half a voxel), so the
        bounds enclose the full extent of the first and last voxels. Each
        pair is returned in (first corner, last corner) order and is not
        sorted.
        """
        last_corner_index = np.array(self.spatial_shape) - 0.5
        first_corner = apply_affine(self.affine, (-0.5, -0.5, -0.5))
        last_corner = apply_affine(self.affine, last_corner_index)
        per_axis = np.array((first_corner, last_corner)).T.tolist()
        bounds_x, bounds_y, bounds_z = per_axis  # type: ignore[misc]
        return bounds_x, bounds_y, bounds_z  # type: ignore[return-value]

    @staticmethod
463
    def _parse_single_path(
464
        path: TypePath,
465
    ) -> Path:
466
        try:
467
            path = Path(path).expanduser()
468
        except TypeError as err:
469
            message = (
470
                f'Expected type str or Path but found {path} with type'
471
                f' {type(path)} instead'
472
            )
473
            raise TypeError(message) from err
474
        except RuntimeError as err:
475
            message = f'Conversion to path not possible for variable: {path}'
476
            raise RuntimeError(message) from err
477
478
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
479
            raise FileNotFoundError(f'File not found: "{path}"')
480
        return path
481
482
    def _parse_path(
483
        self,
484
        path: TypePath | Sequence[TypePath] | None,
485
    ) -> Path | list[Path] | None:
486
        if path is None:
487
            return None
488
        elif isinstance(path, dict):
489
            # https://github.com/TorchIO-project/torchio/pull/838
490
            raise TypeError('The path argument cannot be a dictionary')
491
        elif self._is_paths_sequence(path):
492
            return [self._parse_single_path(p) for p in path]  # type: ignore[union-attr]
493
        else:
494
            return self._parse_single_path(path)  # type: ignore[arg-type]
495
496
    def _parse_tensor(
497
        self,
498
        tensor: TypeData | None,
499
        none_ok: bool = True,
500
    ) -> torch.Tensor | None:
501
        if tensor is None:
502
            if none_ok:
503
                return None
504
            else:
505
                raise RuntimeError('Input tensor cannot be None')
506
        if isinstance(tensor, np.ndarray):
507
            tensor = check_uint_to_int(tensor)
508
            tensor = torch.as_tensor(tensor)
509
        elif not isinstance(tensor, torch.Tensor):
510
            message = (
511
                'Input tensor must be a PyTorch tensor or NumPy array,'
512
                f' but type "{type(tensor)}" was found'
513
            )
514
            raise TypeError(message)
515
        ndim = tensor.ndim
516
        if ndim != 4:
517
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
518
        if tensor.dtype == torch.bool:
519
            tensor = tensor.to(torch.uint8)
520
        if self.check_nans and torch.isnan(tensor).any():
521
            warnings.warn('NaNs found in tensor', RuntimeWarning, stacklevel=2)
522
        return tensor
523
524
    @staticmethod
    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
        """Delegate to ``ensure_4d`` to get data with four dimensions."""
        four_d = ensure_4d(tensor)
        return four_d

    @staticmethod
529
    def _parse_affine(affine: TypeData | None) -> np.ndarray:
530
        if affine is None:
531
            return np.eye(4)
532
        if isinstance(affine, torch.Tensor):
533
            affine = affine.numpy()
534
        if not isinstance(affine, np.ndarray):
535
            bad_type = type(affine)
536
            raise TypeError(f'Affine must be a NumPy array, not {bad_type}')
537
        if affine.shape != (4, 4):
538
            bad_shape = affine.shape
539
            raise ValueError(f'Affine shape must be (4, 4), not {bad_shape}')
540
        return affine.astype(np.float64)
541
542
    @staticmethod
    def _is_paths_sequence(path: TypePath | Sequence[TypePath] | None) -> bool:
        """Whether ``path`` is an iterable of paths (strings never are)."""
        if isinstance(path, str):
            return False
        return is_iterable(path)

    def _is_multipath(self) -> bool:
548
        return self._is_paths_sequence(self.path)
549
550
    def _is_dir(self) -> bool:
551
        is_sequence = self._is_multipath()
552
        if is_sequence:
553
            return False
554
        elif self.path is None:
555
            return False
556
        else:
557
            assert isinstance(self.path, Path)
558
            return self.path.is_dir()
559
560
    def load(self) -> None:
561
        r"""Load the image from disk.
562
563
        Returns:
564
            Tuple containing a 4D tensor of size :math:`(C, W, H, D)` and a 2D
565
            :math:`4 \times 4` affine matrix to convert voxel indices to world
566
            coordinates.
567
        """
568
        if self._loaded:
569
            return
570
571
        paths: list[Path]
572
        if self._is_multipath():
573
            paths = self.path  # type: ignore[assignment]
574
        else:
575
            paths = [self.path]  # type: ignore[list-item]
576
        tensor, affine = self.read_and_check(paths[0])
577
        tensors = [tensor]
578
        for path in paths[1:]:
579
            new_tensor, new_affine = self.read_and_check(path)
580
            if not np.array_equal(affine, new_affine):
581
                message = (
582
                    'Files have different affine matrices.'
583
                    f'\nMatrix of {paths[0]}:'
584
                    f'\n{affine}'
585
                    f'\nMatrix of {path}:'
586
                    f'\n{new_affine}'
587
                )
588
                warnings.warn(message, RuntimeWarning, stacklevel=2)
589
            if not tensor.shape[1:] == new_tensor.shape[1:]:
590
                message = (
591
                    f'Files shape do not match, found {tensor.shape}'
592
                    f'and {new_tensor.shape}'
593
                )
594
                RuntimeError(message)
595
            tensors.append(new_tensor)
596
        tensor = torch.cat(tensors)
597
        self.set_data(tensor)
598
        self.affine = affine
599
        self._loaded = True
600
601
    def unload(self) -> None:
602
        """Unload the image from memory.
603
604
        Raises:
605
            RuntimeError: If the images has not been loaded yet or if no path
606
                is available.
607
        """
608
        if not self._loaded:
609
            message = 'Image cannot be unloaded as it has not been loaded yet'
610
            raise RuntimeError(message)
611
        if self.path is None:
612
            message = (
613
                'Cannot unload image as no path is available'
614
                ' from where the image could be loaded again'
615
            )
616
            raise RuntimeError(message)
617
        self[DATA] = None
618
        self[AFFINE] = None
619
        self._loaded = False
620
621
    def read_and_check(self, path: TypePath) -> TypeDataAffine:
        """Read one file with the image reader and validate the result.

        Args:
            path: File (or directory) passed to the reader.

        Returns:
            Tuple with the parsed 4D tensor and the float64 affine matrix.
        """
        raw_tensor, raw_affine = self.reader(path)
        # Make sure the data type is compatible with PyTorch
        uses_custom_reader = self.reader is not read_image
        if uses_custom_reader and isinstance(raw_tensor, np.ndarray):
            raw_tensor = check_uint_to_int(raw_tensor)
        four_d = self._parse_tensor_shape(raw_tensor)  # type: ignore[assignment]
        tensor = self._parse_tensor(four_d)  # type: ignore[assignment]
        affine = self._parse_affine(raw_affine)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(
                f'NaNs found in file "{path}"',
                RuntimeWarning,
                stacklevel=2,
            )
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool | None = None) -> None:
        """Write the image data and affine to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        data, affine = self.data, self.affine
        write_image(data, affine, path, squeeze=squeeze)

    def is_2d(self) -> bool:
654
        return self.shape[-1] == 1
655
656
    def numpy(self) -> np.ndarray:
657
        """Get a NumPy array containing the image data."""
658
        return np.asarray(self)
659
660
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`.

        Args:
            **kwargs: Forwarded to ``nib_to_sitk``.
        """
        data, affine = self.data, self.affine
        return nib_to_sitk(data, affine, **kwargs)

    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Args:
            sitk_image: SimpleITK image converted with ``sitk_to_nib``.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.IntTensor; memory: 93.8 KiB)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 588.0 KiB)
        """
        tensor, affine = sitk_to_nib(sitk_image)
        image = cls(tensor=tensor, affine=affine)
        return image

    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        Args:
            transpose: If ``True``, the array given to Pillow has shape
                :math:`(H, W, C)`, so the first spatial axis runs
                horizontally; otherwise it has shape :math:`(W, H, C)`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        .. note:: To use this method, Pillow needs to be installed:
            ``pip install Pillow``.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = 'Please install Pillow to use Image.as_pil(): pip install Pillow'
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        channel_count = len(tensor)
        if channel_count not in (1, 3, 4):
            raise NotImplementedError(
                'Only 1, 3 or 4 channels are supported for conversion to Pillow image'
            )
        if channel_count == 1:
            # Replicate the single channel to build an RGB image
            tensor = torch.cat(3 * [tensor])
        order = (3, 2, 1, 0) if transpose else (3, 1, 2, 0)
        array = tensor.permute(*order).clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))

    def to_gif(
        self,
        axis: int,
        duration: float,  # of full gif
        output_path: TypePath,
        loop: int = 0,
        rescale: bool = True,
        optimize: bool = True,
        reverse: bool = False,
    ) -> None:
        """Save an animated GIF that sweeps through the image along an axis.

        Args:
            axis: Spatial axis (0, 1 or 2).
            duration: Duration of the full animation in seconds.
            output_path: Path to the output GIF file.
            loop: Number of times the GIF should loop.
                ``0`` means that it will loop forever.
            rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
                to rescale the intensity values to :math:`[0, 255]`.
            optimize: If ``True``, attempt to compress the palette by
                eliminating unused colors. This is only useful if the palette
                can be compressed to the next smaller power of 2 elements.
            reverse: Reverse the temporal order of frames.
        """
        from ..visualization import make_gif  # avoid circular import

        options = {
            'loop': loop,
            'rescale': rescale,
            'optimize': optimize,
            'reverse': reverse,
        }
        make_gif(self.data, axis, duration, output_path, **options)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get the world position of the image center.

        Args:
            lps: If ``True``, return LPS+ coordinates, i.e. the first two
                components are negated. Otherwise return RAS+ coordinates.
        """
        center_index = (np.array(self.spatial_shape) - 1) / 2
        r, a, s = apply_affine(self.affine, center_index)
        return (-r, -a, s) if lps else (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Store the ``check_nans`` flag on this image.

        Args:
            check_nans: New value for the ``check_nans`` attribute.
        """
        self.check_nans = check_nans
766
767
    def plot(self, **kwargs) -> None:
        """Plot the image.

        2D images are displayed with :meth:`as_pil` followed by ``show()``;
        keyword arguments are ignored in that case. Any other image is passed,
        together with all keyword arguments, to
        :func:`torchio.visualization.plot_volume`.
        """
        if not self.is_2d():
            from ..visualization import plot_volume  # avoid circular import

            plot_volume(self, **kwargs)
            return
        self.as_pil().show()
775
776
    def show(self, viewer_path: TypePath | None = None) -> None:
        """Open the image in an external viewer application.

        Args:
            viewer_path: Path to the viewer executable. If ``None``, the
                ``SITK_SHOW_COMMAND`` environment variable is used; if that is
                not set either, TorchIO tries to locate
                `ITK-SNAP <http://www.itksnap.org/pmwiki/pmwiki.php>`_ and
                `3D Slicer <https://www.slicer.org/>`_.

        Raises:
            RuntimeError: If the viewer is not found.
        """
        image_to_show = self.as_sitk()
        viewer = sitk.ImageViewer()
        # Saving label maps as .seg.nrrd makes 3D Slicer load them as
        # segmentation nodes
        if type(self).__name__ == 'LabelMap':
            viewer.SetFileExtension('.seg.nrrd')
        if viewer_path is not None:
            viewer.SetApplication(str(viewer_path))
        try:
            viewer.Execute(image_to_show)
        except RuntimeError as error:
            # Launching failed: retry with a viewer guessed by TorchIO
            fallback_viewer = guess_external_viewer()
            if fallback_viewer is None:
                raise RuntimeError(
                    'No external viewer has been found. Please set the'
                    ' environment variable SITK_SHOW_COMMAND to a viewer of'
                    ' your choice'
                ) from error
            viewer.SetApplication(str(fallback_viewer))
            viewer.Execute(image_to_show)
810
811
    def _crop_from_slices(
        self,
        slices: TypeSlice | tuple[TypeSlice, ...],
    ) -> Image:
        """Crop the image using NumPy-style indexing of its spatial dimensions.

        Each element of ``slices`` indexes one spatial dimension, in order.
        Spatial dimensions without a corresponding element are not cropped,
        so an empty tuple (``image[()]``) returns an uncropped result.

        Args:
            slices: A slice or an integer, or a tuple of them with at most one
                element per spatial dimension. An integer ``i`` keeps only
                index ``i`` (the dimension is kept, with size 1). Negative
                integers count from the end, as in Python sequences.

        Returns:
            The result of applying :class:`~torchio.transforms.Crop`, with the
            computed cropping bounds, to this image.

        Raises:
            IndexError: If more indices than spatial dimensions are given, or
                if an integer index is out of range.
            NotImplementedError: If an ``Ellipsis`` is used.
            TypeError: If an element is neither a slice nor an integer.
            ValueError: If a slice has a step different from 1.
        """
        from ..transforms import Crop

        # The cropping tuple built below holds two bounds per spatial dim
        num_spatial_dims = 3
        slices_tuple = to_tuple(slices)  # type: ignore[assignment]
        if len(slices_tuple) > num_spatial_dims:
            message = (
                f'Too many indices: got {len(slices_tuple)}, but images have'
                f' {num_spatial_dims} spatial dimensions'
            )
            raise IndexError(message)
        spatial_shape = self.spatial_shape
        cropping: list[int] = []
        for dim, slice_ in enumerate(slices_tuple):
            shape_dim = spatial_shape[dim]
            if isinstance(slice_, slice):
                pass
            elif slice_ is Ellipsis:
                message = 'Ellipsis slicing is not supported yet'
                raise NotImplementedError(message)
            elif isinstance(slice_, int):
                # Normalize negative indices first: slice(-1, 0) would be
                # empty instead of selecting the last element
                index = slice_ + shape_dim if slice_ < 0 else slice_
                if not 0 <= index < shape_dim:
                    message = (
                        f'Index {slice_} is out of bounds for spatial'
                        f' dimension {dim} with size {shape_dim}'
                    )
                    raise IndexError(message)
                slice_ = slice(index, index + 1)  # type: ignore[assignment]
            else:
                message = f'Slice type not understood: "{type(slice_)}"'
                raise TypeError(message)
            assert isinstance(slice_, slice)
            start, stop, step = slice_.indices(shape_dim)
            if step != 1:
                message = (
                    'Slicing with steps different from 1 is not supported yet.'
                    ' Use the Crop transform instead'
                )
                raise ValueError(message)
            # Number of voxels removed at the beginning and at the end
            cropping.extend([start, shape_dim - stop])
        # Dimensions without an explicit index are kept whole. Padding by
        # count (instead of reusing the loop variable) also handles an empty
        # tuple of slices, for which the loop body never runs
        missing_dims = num_spatial_dims - len(slices_tuple)
        cropping.extend([0, 0] * missing_dims)
        w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
        cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin  # making mypy happy
        return Crop(cropping_arg)(self)  # type: ignore[return-value]
848
849
850
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        # The image type is fixed for this subclass; any other value is an error
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)

    def hist(self, **kwargs) -> None:
        """Plot a histogram of all the voxel values.

        Args:
            **kwargs: Forwarded to
                :func:`torchio.visualization.plot_histogram`.
        """
        from ..visualization import plot_histogram

        # detach() and cpu() let this work for tensors that require gradients
        # or live on a GPU, for which calling .numpy() directly raises
        x = self.data.detach().cpu().flatten().numpy()
        plot_histogram(x, **kwargs)
885
886
887
class LabelMap(Image):
    """Image whose pixel values represent segmentation labels.

    A sequence of paths to 3D images can be passed to create a 4D image.
    This is useful to create a
    `tissue probability map (TPM) <https://andysbrainbook.readthedocs.io/en/latest/SPM/SPM_Short_Course/SPM_04_Preprocessing/04_SPM_Segmentation.html#tissue-probability-maps>`_,
    which contains the probability of each voxel belonging to a certain tissue type,
    or to create a label map with overlapping labels.

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> binary_tensor = torch.rand(1, 128, 128, 68) > 0.5
        >>> label_map = tio.LabelMap(tensor=binary_tensor)  # from a tensor
        >>> label_map = tio.LabelMap('t1_seg.nii.gz')  # from a file
        >>> # Create a 4D tissue probability map from different 3D images
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
        >>> tpm = tio.LabelMap(tissues)

    See :class:`~torchio.Image` for more information.
    """

    def __init__(self, *args, **kwargs):
        # The image type is fixed for this subclass; any other value is an error
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)

    def count_nonzero(self) -> int:
        """Get the number of voxels that are not 0."""
        return int(self.data.count_nonzero().item())

    def count_labels(self) -> dict[int, int]:
        """Get the number of voxels in each label.

        Returns:
            Dictionary mapping each distinct voxel value to the number of
            voxels that have it, with keys in ascending order.
        """
        # torch.unique counts in native code, instead of materializing a
        # Python list with one element per voxel and counting it with Counter
        labels, counts = torch.unique(
            self.data,
            sorted=True,
            return_counts=True,
        )
        return dict(zip(labels.tolist(), counts.tolist()))
931