Passed
Push — master ( 27308d...e28dfa )
by Fernando
01:20
created

torchio.data.image.Image._parse_tensor_shape()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import warnings
2
from pathlib import Path
3
from collections.abc import Iterable
4
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List, Callable
5
6
import torch
7
import humanize
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
from deprecated import deprecated
12
13
from ..utils import get_stem
14
from ..typing import (
15
    TypeData,
16
    TypePath,
17
    TypeTripletInt,
18
    TypeTripletFloat,
19
    TypeDirection3D,
20
)
21
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
22
from .io import (
23
    ensure_4d,
24
    read_image,
25
    write_image,
26
    nib_to_sitk,
27
    sitk_to_nib,
28
    check_uint_to_int,
29
    get_rotation_and_spacing_from_affine,
30
    get_sitk_metadata_from_ras_affine,
31
    read_shape,
32
    read_affine,
33
)
34
35
36
# Dictionary keys that Image manages itself. Passing any of them as an extra
# keyword argument to Image.__init__ raises a ValueError.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# Pair of world coordinates along one axis, and one such pair per spatial axis
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Reason shown by the deprecated ``Image.data`` property setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
44
45
46
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
            coordinates. If ``None``, an identity matrix will be used. See the
            `NiBabel docs on coordinates`_ for more information.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the read tensor will be permuted so the
            last dimension becomes the first. This is useful, e.g., when
            NIfTI images have been saved with the channels dimension being the
            fourth instead of the fifth.
        reader: Callable object that takes a path and returns a 4D tensor and a
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
            is saved in a custom format, such as ``.npy`` (see example below).
            If the affine matrix is ``None``, an identity matrix will be used.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> import numpy as np
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; dtype: torch.FloatTensor; memory: 44.0 MiB)
        >>> image.save('doubled_image.nii.gz')
        >>> numpy_reader = lambda path: (np.load(path), np.eye(4))
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            reader: Callable = read_image,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead',
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Both inputs are validated, even if only a path was given
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Store the already-parsed values directly. Going through
            # set_data() and the affine setter would parse them a second time
            # (including a second full NaN scan when check_nans is True).
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = [
            f'shape: {self.shape}',
            f'spacing: {self.get_spacing_string()}',
            f'orientation: {"".join(self.orientation)}+',
        ]
        if self._loaded:
            properties.append(f'dtype: {self.data.type()}')
            memory = humanize.naturalsize(self.memory, binary=True)
            properties.append(f'memory: {memory}')
        else:
            properties.append(f'path: "{self.path}"')
        return f'{self.__class__.__name__}({"; ".join(properties)})'

    def __getitem__(self, item):
        # Data and affine are loaded lazily on first access
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        # Extra (non-protected) items are passed through; their values are
        # shared with this image, not copied
        kwargs.update({
            key: value
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        })
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter  # type: ignore
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    def _needs_data_for_metadata(self) -> bool:
        """Return ``True`` if shape and affine must be taken from the data.

        A header-only read (:func:`read_shape` / :func:`read_affine`) is used
        only for a single file read with the default reader. If the data is
        already loaded, a custom :attr:`reader` is used, several paths were
        given or the path is a directory (e.g. DICOM), the metadata is taken
        from the data, which is loaded if needed.
        """
        if self._loaded:
            return True
        custom_reader = self.reader is not read_image
        multipath = not isinstance(self.path, (str, Path))
        return custom_reader or multipath or self.path.is_dir()

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        # Custom readers, directories (probably DICOM) and lists of paths
        # cannot be read with read_affine(), so the image is loaded instead
        if self._needs_data_for_metadata():
            return self[AFFINE]
        return read_affine(self.path)

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type stored under the ``TYPE`` key, e.g. ``INTENSITY``."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        if self._needs_data_for_metadata():
            return tuple(self.data.shape)
        return read_shape(self.path)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def direction(self) -> TypeDirection3D:
        """Axes direction computed from :attr:`affine` (with ``lps=False``).

        This is the third value returned by
        :func:`get_sitk_metadata_from_ras_affine`.
        """
        _, _, direction = get_sitk_metadata_from_ras_affine(
            self.affine, lps=False)
        return direction

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def origin(self) -> Tuple[float, float, float]:
        """Center of first voxel in array, in mm."""
        return tuple(self.affine[:3, 3])

    @property
    def itemsize(self) -> int:
        """Element size of the data type."""
        return self.data.element_size()

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * self.itemsize

    @property
    def bounds(self) -> np.ndarray:
        """World positions of the centers of the first and last voxels.

        Returns a :math:`2 \\times 3` array whose rows are the world
        coordinates of voxel index ``(0, 0, 0)`` and of the last voxel index.
        """
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        return len(self.data)

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), so that it is valid
            both for the 4D tensor and for its spatial shape.

        Raises:
            ValueError: If ``axis`` is not a non-empty string or cannot be
                matched to an axis of the image.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the code of the opposite direction along the same axis.

        Args:
            axis: One of ``'L'``, ``'R'``, ``'P'``, ``'A'``, ``'I'``, ``'S'``,
                ``'T'`` (top) or ``'B'`` (bottom).

        Raises:
            ValueError: If ``axis`` is not one of the codes above.
        """
        opposites = {
            'R': 'L', 'L': 'R',
            'A': 'P', 'P': 'A',
            'I': 'S', 'S': 'I',
            'T': 'B', 'B': 'T',  # top / bottom
        }
        if isinstance(axis, str) and axis in opposites:
            return opposites[axis]
        values = ', '.join('LRPAISTB')
        message = f'Axis not understood. Please use one of: {values}'
        raise ValueError(message)

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    def get_bounds(self) -> TypeBounds:
        """Get world coordinates of the outer corners of the image, per axis.

        The corners are the outer edges (half a voxel beyond the centers) of
        the first and last voxels. Each returned pair is ``(first, last)``
        along one world axis; if the affine flips that axis, ``first`` is
        larger than ``last``.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an expanded :class:`Path` that must exist."""
        try:
            path = Path(path).expanduser()
        except TypeError as e:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message) from e
        except RuntimeError as e:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from e

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse one path or a sequence of paths (strings are single paths)."""
        if path is None:
            return None
        if isinstance(path, Iterable) and not isinstance(path, str):
            return [self._parse_single_path(p) for p in path]
        return self._parse_single_path(path)

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate ``tensor`` and convert it to a 4D :class:`torch.Tensor`.

        NumPy arrays go through :func:`check_uint_to_int` before conversion
        and boolean tensors are cast to ``torch.uint8``.

        Raises:
            RuntimeError: If ``tensor`` is ``None`` and ``none_ok`` is False.
            TypeError: If ``tensor`` is not a tensor or NumPy array.
            ValueError: If ``tensor`` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.as_tensor(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        ndim = tensor.ndim
        if ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    @staticmethod
    def _parse_tensor_shape(tensor: TypeData) -> TypeData:
        """Pass data returned by :attr:`reader` through :func:`ensure_4d`.

        This runs before :meth:`_parse_tensor`, which requires 4D input.
        """
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        """Return ``affine`` as a float64 :math:`4 \\times 4` NumPy array.

        ``None`` becomes the identity matrix.

        Raises:
            TypeError: If ``affine`` is not a tensor or NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            # detach() and cpu() make this work for tensors that require
            # gradients or live on a GPU; they are no-ops otherwise
            affine = affine.detach().cpu().numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine.astype(np.float64)

    def load(self) -> None:
        r"""Load the image from disk, if it has not been loaded yet.

        Every path in :attr:`path` is read with :attr:`reader` and the tensors
        are concatenated along the channel dimension. The result is stored in
        :attr:`data` and :attr:`affine`; nothing is returned.

        The affine matrix of the first file is used. A :class:`RuntimeWarning`
        is issued if any other file has a different affine matrix.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        first_path = paths[0]
        tensor, affine = self.read_and_check(first_path)
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {first_path}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file with :attr:`reader` and validate tensor and affine.

        If :attr:`channels_last` is True, the last dimension is moved first.
        """
        tensor, affine = self.reader(path)
        tensor = self._parse_tensor_shape(tensor)
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: Optional[bool] = None) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.IntTensor; memory: 93.8 KiB)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 588.0 KiB)
        """
        tensor, affine = sitk_to_nib(sitk_image)
        return cls(tensor=tensor, affine=affine)

    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.
        .. note:: To use this method, `Pillow` needs to be installed:
            `pip install Pillow`.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = (
                'Please install Pillow to use Image.as_pil():'
                ' pip install Pillow'
            )
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        if transpose:
            tensor = tensor.permute(3, 2, 1, 0)
        else:
            tensor = tensor.permute(3, 1, 2, 0)
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable the NaN check performed when parsing data."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
613
614
615
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The image type is always :attr:`torchio.INTENSITY`; passing any other
    ``type`` keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' is treated as INTENSITY; anything else is rejected
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
642
643
644
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    The image type is always :attr:`torchio.LABEL`; passing any other
    ``type`` keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' is treated as LABEL; anything else is rejected
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
671