Passed
Push — master ( dc38a5...a905db )
by Fernando
01:48
created

torchio.data.image.Image.origin()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import warnings
2
from pathlib import Path
3
from collections.abc import Iterable
4
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List, Callable
5
6
import torch
7
import humanize
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
from deprecated import deprecated
12
13
from ..utils import get_stem
14
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
15
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
16
from .io import (
17
    ensure_4d,
18
    read_image,
19
    write_image,
20
    nib_to_sitk,
21
    sitk_to_nib,
22
    check_uint_to_int,
23
    get_rotation_and_spacing_from_affine,
24
)
25
26
27
# Dictionary keys managed by Image itself; user kwargs may not override them
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# (first, last) world coordinate along one spatial axis
TypeBound = Tuple[float, float]
# One TypeBound per spatial axis
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]


# Reason shown by the deprecated ``Image.data`` property setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
35
36
37
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a warning is issued and :attr:`torchio.INTENSITY` is
            used.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`. Boolean data is converted to ``torch.uint8``.
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
            coordinates. If ``None``, an identity matrix will be used. See the
            `NiBabel docs on coordinates`_ for more information.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the read tensor will be permuted so the
            last dimension becomes the first. This is useful, e.g., when
            NIfTI images have been saved with the channels dimension being the
            fourth instead of the fifth.
        reader: Callable object that takes a path and returns a 4D tensor and a
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
            is saved in a custom format, such as ``.npy`` (see example below).
            If the affine matrix is ``None``, an identity matrix will be used.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters. The keys used internally by the image
            (data, affine, type, path and stem) are not allowed.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> import numpy as np
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> image.save('doubled_image.nii.gz')
        >>> numpy_reader = lambda path: (np.load(path), np.eye(4))
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            reader: Callable = read_image,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead',
                stacklevel=2,  # point at the caller, not at this constructor
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Validate both inputs now so invalid values fail at construction,
        # even when the data will be read lazily from ``path``
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Both values are already parsed: store them directly so that the
            # checks (and the optional NaN warning) do not run a second time
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        # dict.__init__ adds kwargs without clearing DATA/AFFINE set above
        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # Only describe the data if it is in memory, to keep repr lazy
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            properties = [f'path: "{self.path}"']
        return f'{self.__class__.__name__}({"; ".join(properties)})'

    def __getitem__(self, item):
        # Accessing the data or the affine triggers the lazy load
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        # Extra items are passed by reference (shallow copy)
        kwargs.update(
            (key, value)
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        )
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def origin(self) -> Tuple[float, float, float]:
        """Center of first voxel in array, in mm."""
        return tuple(self.affine[:3, 3])

    @property
    def itemsize(self) -> int:
        """Element size of the data type."""
        return self.data.element_size()

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * self.itemsize

    @property
    def bounds(self) -> np.ndarray:
        """World coordinates of the centers of the first and last voxels.

        Returns:
            Array of shape :math:`(2, 3)`: the center of voxel ``(0, 0, 0)``
            and the center of the last voxel along every axis.
        """
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        return len(self.data)

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            Negative index of the corresponding spatial dimension, so it can
            be used both on 4D tensors and on spatial shapes.

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string or its first
                letter is not a known axis code.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            # An empty string would otherwise crash on indexing below
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            # The image may be oriented towards the opposite direction
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the code of the opposite direction along the same axis.

        Args:
            axis: One of ``'L'``, ``'R'``, ``'P'``, ``'A'``, ``'I'``, ``'S'``,
                ``'T'`` or ``'B'``.

        Raises:
            ValueError: If :attr:`axis` is not one of the codes above.
        """
        opposites = {
            'R': 'L', 'L': 'R',
            'A': 'P', 'P': 'A',
            'I': 'S', 'S': 'I',
            'T': 'B', 'B': 'T',
        }
        # isinstance check keeps unhashable inputs on the ValueError path
        if isinstance(axis, str) and axis in opposites:
            return opposites[axis]
        values = ', '.join('LRPAISTB')
        message = f'Axis not understood. Please use one of: {values}'
        raise ValueError(message)

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    def get_bounds(self) -> TypeBounds:
        """Get the world coordinates of the outer edges of the image.

        For each spatial axis, the first value corresponds to the outer edge
        of the first voxel and the second value to the outer edge of the last
        voxel. The pair is not sorted, so for an axis whose direction in the
        affine is negative, the first value is the larger one.
        """
        # Voxel centers are at integer indices, so edges are half a voxel away
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert to :class:`~pathlib.Path` and check that it exists."""
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse a single path or each element of a sequence of paths."""
        if path is None:
            return None
        if isinstance(path, Iterable) and not isinstance(path, str):
            return [self._parse_single_path(p) for p in path]
        return self._parse_single_path(path)

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate and convert input data into a 4D tensor.

        Args:
            tensor: PyTorch tensor or NumPy array, or ``None``.
            none_ok: If ``True``, ``None`` is returned for a ``None`` input.
                Otherwise, a :class:`RuntimeError` is raised.

        Raises:
            TypeError: If the input is not a tensor or a NumPy array.
            ValueError: If the input is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.as_tensor(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        ndim = tensor.ndim
        if ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the output of ``ensure_4d`` applied to the tensor."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        """Validate an affine matrix and return it as a float64 array.

        ``None`` is converted to a :math:`4 \\times 4` identity matrix.

        Raises:
            TypeError: If the affine is not a tensor or a NumPy array.
            ValueError: If the affine shape is not :math:`(4, 4)`.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            # detach/cpu so that tensors requiring grad or on GPU also work
            affine = affine.detach().cpu().numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine.astype(np.float64)

    def load(self) -> None:
        r"""Load the image from disk.

        The files in :attr:`path` are read, their tensors are concatenated
        along the channel dimension and stored in :attr:`data`, and the affine
        matrix of the first file is stored in :attr:`affine`. Nothing is done
        if the image is already loaded.

        A :class:`RuntimeWarning` is issued if the files have different affine
        matrices.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file with :attr:`reader` and validate its contents.

        Returns:
            Tuple containing a 4D tensor of size :math:`(C, W, H, D)` and a 2D
            :math:`4 \\times 4` affine matrix to convert voxel indices to world
            coordinates.
        """
        tensor, affine = self.reader(path)
        tensor = self.parse_tensor_shape(tensor)
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 93.8 KiB; dtype: torch.IntTensor)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; memory: 588.0 KiB; dtype: torch.FloatTensor)
        """
        tensor, affine = sitk_to_nib(sitk_image)
        return cls(tensor=tensor, affine=affine)

    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        Args:
            transpose: If ``True``, the first spatial axis (width) is used as
                the columns of the PIL image. Otherwise, it is used as the
                rows.

        Raises:
            RuntimeError: If Pillow is not installed, the image is not 2D or
                it does not have 1 or 3 channels.

        .. note:: Values will be clamped to 0-255 and cast to uint8.
        .. note:: To use this method, `Pillow` needs to be installed:
            `pip install Pillow`.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = (
                'Please install Pillow to use Image.as_pil():'
                ' pip install Pillow'
            )
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            # Grayscale: replicate the channel to build an RGB image
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        # Move channels last; the leading singleton is the 2D depth axis
        if transpose:
            tensor = tensor.permute(3, 2, 1, 0)
        else:
            tensor = tensor.permute(3, 1, 2, 0)
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable the NaN check done when parsing data."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
584
585
586
class ScalarImage(Image):
    """Image whose voxel values are scalar intensities.

    The image type is always :attr:`torchio.INTENSITY`; passing any other
    ``type`` raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is treated as already being INTENSITY
        if kwargs.get('type', INTENSITY) != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
613
614
615
class LabelMap(Image):
    """Image whose voxel values are categorical labels.

    The image type is always :attr:`torchio.LABEL`; passing any other
    ``type`` raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is treated as already being LABEL
        if kwargs.get('type', LABEL) != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
642