Passed
Push — master ( 5ab5d0...c41f84 )
by Fernando
01:14
created

torchio.data.image.Image.from_sitk()   A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 16
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
import warnings
2
from pathlib import Path
3
from collections.abc import Iterable
4
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List, Callable
5
6
import torch
7
import humanize
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
from deprecated import deprecated
12
13
from ..utils import get_stem
14
from ..typing import (
15
    TypeData,
16
    TypePath,
17
    TypeTripletInt,
18
    TypeTripletFloat,
19
    TypeDirection3D,
20
)
21
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
22
from .io import (
23
    ensure_4d,
24
    read_image,
25
    write_image,
26
    nib_to_sitk,
27
    sitk_to_nib,
28
    check_uint_to_int,
29
    get_rotation_and_spacing_from_affine,
30
    get_sitk_metadata_from_ras_affine,
31
    read_shape,
32
    read_affine,
33
)
34
35
36
# Keys written by the image itself; Image.__init__ rejects them as extra kwargs
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# One (first, last) pair of world coordinates, and one such pair per axis
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Reason shown by the deprecated ``Image.data`` property setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
44
45
46
class Image(dict):
47
    r"""TorchIO image.
48
49
    For information about medical image orientation, check out `NiBabel docs`_,
50
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
51
    `SimpleITK docs`_.
52
53
    Args:
54
        path: Path to a file or sequence of paths to files that can be read by
55
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
56
            DICOM files. If :attr:`tensor` is given, the data in
57
            :attr:`path` will not be read.
58
            If a sequence of paths is given, data
59
            will be concatenated on the channel dimension so spatial
60
            dimensions must match.
61
        type: Type of image, such as :attr:`torchio.INTENSITY` or
62
            :attr:`torchio.LABEL`. This will be used by the transforms to
63
            decide whether to apply an operation, or which interpolation to use
64
            when resampling. For example, `preprocessing`_ and `augmentation`_
65
            intensity transforms will only be applied to images with type
66
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
67
            all types, and nearest neighbor interpolation is always used to
68
            resample images with type :attr:`torchio.LABEL`.
69
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
70
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
71
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
72
            :class:`torch.Tensor` or NumPy array with dimensions
73
            :math:`(C, W, H, D)`.
74
        affine: :math:`4 \times 4` matrix to convert voxel coordinates to world
75
            coordinates. If ``None``, an identity matrix will be used. See the
76
            `NiBabel docs on coordinates`_ for more information.
77
        check_nans: If ``True``, issues a warning if NaNs are found
78
            in the image. If ``False``, images will not be checked for the
79
            presence of NaNs.
80
        channels_last: If ``True``, the read tensor will be permuted so the
81
            last dimension becomes the first. This is useful, e.g., when
82
            NIfTI images have been saved with the channels dimension being the
83
            fourth instead of the fifth.
84
        reader: Callable object that takes a path and returns a 4D tensor and a
85
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
86
            is saved in a custom format, such as ``.npy`` (see example below).
87
            If the affine matrix is ``None``, an identity matrix will be used.
88
        **kwargs: Items that will be added to the image dictionary, e.g.
89
            acquisition parameters.
90
91
    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
92
    when needed.
93
94
    Example:
95
        >>> import torchio as tio
96
        >>> import numpy as np
97
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
98
        >>> image  # not loaded yet
99
        ScalarImage(path: t1.nii.gz; type: intensity)
100
        >>> times_two = 2 * image.data  # data is loaded and cached here
101
        >>> image
102
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
103
        >>> image.save('doubled_image.nii.gz')
104
        >>> numpy_reader = lambda path: (np.load(path), np.eye(4))
105
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)
106
107
    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
108
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
109
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
110
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
111
    .. _NiBabel docs on coordinates: https://nipy.org/nibabel/coordinate_systems.html#the-affine-matrix-as-a-transformation-between-spaces
112
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
113
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
114
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
115
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
116
    """
117
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: str = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            reader: Callable = read_image,
            **kwargs: Dict[str, Any],
            ):
        """Store the reading options and, if given, the tensor and affine.

        Raises:
            ValueError: If neither ``path`` nor ``tensor`` is given, or if a
                key in ``kwargs`` is one of ``PROTECTED_KEYS``.
        """
        # These options must exist before any tensor is parsed, because
        # _parse_tensor() reads self.check_nans
        self.check_nans = check_nans
        self.channels_last = channels_last
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead',
            )
            type = INTENSITY

        nothing_to_read = path is None and tensor is None
        if nothing_to_read:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        parsed_tensor = self._parse_tensor(tensor)
        parsed_affine = self._parse_affine(affine)
        if parsed_tensor is not None:
            # Data passed in memory: nothing will be read from disk later
            self.set_data(parsed_tensor)
            self.affine = parsed_affine
            self._loaded = True

        reserved = [key for key in PROTECTED_KEYS if key in kwargs]
        if reserved:
            message = f'Key "{reserved[0]}" is reserved. Use a different one'
            raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        has_path = self.path is not None
        self[PATH] = str(self.path) if has_path else ''
        self[STEM] = get_stem(self.path) if has_path else ''
        self[TYPE] = type
161
162
    def __repr__(self):
163
        properties = []
164
        properties.extend([
165
            f'shape: {self.shape}',
166
            f'spacing: {self.get_spacing_string()}',
167
            f'orientation: {"".join(self.orientation)}+',
168
        ])
169
        if self._loaded:
170
            properties.append(f'dtype: {self.data.type()}')
171
            properties.append(f'memory: {humanize.naturalsize(self.memory, binary=True)}')
172
        else:
173
            properties.append(f'path: "{self.path}"')
174
175
        properties = '; '.join(properties)
176
        string = f'{self.__class__.__name__}({properties})'
177
        return string
178
179
    def __getitem__(self, item):
        """Return the value stored under ``item``.

        Asking for the data or affine key before it exists triggers
        :meth:`load` first.
        """
        must_load = item in (DATA, AFFINE) and item not in self
        if must_load:
            self.load()
        return super().__getitem__(item)
184
185
    def __array__(self):
186
        return self.data.numpy()
187
188
    def __copy__(self):
        """Build a new instance of the same class from this image's contents.

        The tensor, affine, type and path are passed to the constructor,
        together with every non-protected item. Values are passed as they
        are, i.e. they are not duplicated here.
        """
        init_kwargs = {
            'tensor': self.data,
            'affine': self.affine,
            'type': self.type,
            'path': self.path,
        }
        init_kwargs.update(
            (key, value)
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        )
        return self.__class__(**init_kwargs)
199
200
    @property
    def data(self) -> torch.Tensor:
        """Image tensor stored under the data key.

        Same as :attr:`Image.tensor`.
        """
        return self[DATA]

    @data.setter
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        # Deprecated entry point kept for compatibility; forwards to set_data()
        self.set_data(tensor)
209
210
    def set_data(self, tensor: TypeData):
        """Validate a 4D tensor and store it in the :attr:`data` key.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`. ``None``
                is rejected.
        """
        validated = self._parse_tensor(tensor, none_ok=False)
        self[DATA] = validated
217
218
    @property
219
    def tensor(self) -> torch.Tensor:
220
        """Tensor data. Same as :class:`Image.data`."""
221
        return self.data
222
223
    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates.

        The matrix is obtained with ``read_affine`` without loading the image
        only when the image is not loaded, the default reader is used and the
        path is a single path that is not a directory. In every other case
        (image already loaded, directory that is probably DICOM, list of
        paths, custom reader) the stored matrix is returned, which loads the
        image first if needed.
        """
        # Same conditions as in the shape property: a list of paths has no
        # is_dir(), and the default reading functions might not understand a
        # file meant for a custom reader
        single_path = isinstance(self.path, (str, Path))
        default_reader = self.reader is read_image
        can_read_header_only = (
            not self._loaded
            and default_reader
            and single_path
            and not self.path.is_dir()
        )
        if can_read_header_only:
            return read_affine(self.path)
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        # Validate and convert before storing (see _parse_affine)
        self[AFFINE] = self._parse_affine(matrix)
236
237
    @property
    def type(self) -> str:
        """Image type stored under the type key."""
        image_type = self[TYPE]
        return image_type
240
241
    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`.

        The shape is obtained with ``read_shape`` when the image is not
        loaded, the default reader is used and the path is a single path
        that is not a directory. Otherwise, it is taken from the data.
        """
        custom_reader = self.reader is not read_image
        if not (self._loaded or custom_reader):
            single_path = isinstance(self.path, (str, Path))
            if single_path and not self.path.is_dir():
                return read_shape(self.path)
        return tuple(self.data.shape)
251
252
    @property
253
    def spatial_shape(self) -> TypeTripletInt:
254
        """Tensor spatial shape as :math:`(W, H, D)`."""
255
        return self.shape[1:]
256
257
    def check_is_2d(self) -> None:
258
        if not self.is_2d():
259
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
260
            raise RuntimeError(message)
261
262
    @property
263
    def height(self) -> int:
264
        """Image height, if 2D."""
265
        self.check_is_2d()
266
        return self.spatial_shape[1]
267
268
    @property
269
    def width(self) -> int:
270
        """Image width, if 2D."""
271
        self.check_is_2d()
272
        return self.spatial_shape[0]
273
274
    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes computed by NiBabel from the affine matrix."""
        axis_codes = nib.aff2axcodes(self.affine)
        return axis_codes
278
279
    @property
    def direction(self) -> TypeDirection3D:
        """Direction obtained from the affine matrix.

        This is the third value returned by
        ``get_sitk_metadata_from_ras_affine`` when called with ``lps=False``.
        """
        metadata = get_sitk_metadata_from_ras_affine(self.affine, lps=False)
        _, _, direction = metadata
        return direction
284
285
    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        rotation_and_spacing = get_rotation_and_spacing_from_affine(self.affine)
        _, spacing = rotation_and_spacing
        return tuple(spacing)
290
291
    @property
292
    def origin(self) -> Tuple[float, float, float]:
293
        """Center of first voxel in array, in mm."""
294
        return tuple(self.affine[:3, 3])
295
296
    @property
297
    def itemsize(self):
298
        """Element size of the data type."""
299
        return self.data.element_size()
300
301
    @property
302
    def memory(self) -> float:
303
        """Number of Bytes that the tensor takes in the RAM."""
304
        return np.prod(self.shape) * self.itemsize
305
306
    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates.

        Returns:
            Array whose first row is the world position of index
            ``(0, 0, 0)`` and whose second row is the position of the last
            index along every spatial axis.
        """
        first_index = 0, 0, 0
        last_index = np.array(self.spatial_shape) - 1
        corners = [
            nib.affines.apply_affine(self.affine, index)
            for index in (first_index, last_index)
        ]
        return np.array(corners)
314
315
    @property
316
    def num_channels(self) -> int:
317
        """Get the number of channels in the associated 4D tensor."""
318
        return len(self.data)
319
320
    def axis_name_to_index(self, axis: str) -> int:
321
        """Convert an axis name to an axis index.
322
323
        Args:
324
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
325
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
326
                versions and first letters are also valid, as only the first
327
                letter will be used.
328
329
        .. note:: If you are working with animals, you should probably use
330
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
331
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
332
            respectively.
333
334
        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
335
            ``'Left'`` and ``'Right'``.
336
        """
337
        # Top and bottom are used for the vertical 2D axis as the use of
338
        # Height vs Horizontal might be ambiguous
339
340
        if not isinstance(axis, str):
341
            raise ValueError('Axis must be a string')
342
        axis = axis[0].upper()
343
344
        # Generally, TorchIO tensors are (C, W, H, D)
345
        if axis in 'TB':  # Top, Bottom
346
            return -2
347
        else:
348
            try:
349
                index = self.orientation.index(axis)
350
            except ValueError:
351
                index = self.orientation.index(self.flip_axis(axis))
352
            # Return negative indices so that it does not matter whether we
353
            # refer to spatial dimensions or not
354
            index = -3 + index
355
            return index
356
357
    # flake8: noqa: E701
358
    @staticmethod
359
    def flip_axis(axis: str) -> str:
360
        if axis == 'R': flipped_axis = 'L'
361
        elif axis == 'L': flipped_axis = 'R'
362
        elif axis == 'A': flipped_axis = 'P'
363
        elif axis == 'P': flipped_axis = 'A'
364
        elif axis == 'I': flipped_axis = 'S'
365
        elif axis == 'S': flipped_axis = 'I'
366
        elif axis == 'T': flipped_axis = 'B'  # top / bottom
367
        elif axis == 'B': flipped_axis = 'T'
368
        else:
369
            values = ', '.join('LRPAISTB')
370
            message = f'Axis not understood. Please use one of: {values}'
371
            raise ValueError(message)
372
        return flipped_axis
373
374
    def get_spacing_string(self) -> str:
375
        strings = [f'{n:.2f}' for n in self.spacing]
376
        string = f'({", ".join(strings)})'
377
        return string
378
379
    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image.

        The corners are half a voxel before the first index and half a voxel
        after the last index along each axis.

        Returns:
            Three pairs, one per axis, each with the coordinate of the first
            and last corner along that axis.
        """
        half_voxel = 0.5
        first_index = 3 * (-half_voxel,)
        last_index = np.array(self.spatial_shape) - half_voxel
        corners = np.array((
            nib.affines.apply_affine(self.affine, first_index),
            nib.affines.apply_affine(self.affine, last_index),
        ))
        # Transpose so that each row holds both corners for one axis
        bounds_x, bounds_y, bounds_z = corners.T.tolist()
        return bounds_x, bounds_y, bounds_z
388
389
    @staticmethod
390
    def _parse_single_path(
391
            path: TypePath
392
            ) -> Path:
393
        try:
394
            path = Path(path).expanduser()
395
        except TypeError:
396
            message = (
397
                f'Expected type str or Path but found {path} with type'
398
                f' {type(path)} instead'
399
            )
400
            raise TypeError(message)
401
        except RuntimeError:
402
            message = (
403
                f'Conversion to path not possible for variable: {path}'
404
            )
405
            raise RuntimeError(message)
406
407
        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
408
            raise FileNotFoundError(f'File not found: "{path}"')
409
        return path
410
411
    def _parse_path(
412
            self,
413
            path: Union[TypePath, Sequence[TypePath]]
414
            ) -> Optional[Union[Path, List[Path]]]:
415
        if path is None:
416
            return None
417
        if isinstance(path, Iterable) and not isinstance(path, str):
418
            return [self._parse_single_path(p) for p in path]
419
        else:
420
            return self._parse_single_path(path)
421
422
    def _parse_tensor(
423
            self,
424
            tensor: TypeData,
425
            none_ok: bool = True,
426
            ) -> torch.Tensor:
427
        if tensor is None:
428
            if none_ok:
429
                return None
430
            else:
431
                raise RuntimeError('Input tensor cannot be None')
432
        if isinstance(tensor, np.ndarray):
433
            tensor = check_uint_to_int(tensor)
434
            tensor = torch.as_tensor(tensor)
435
        elif not isinstance(tensor, torch.Tensor):
436
            message = (
437
                'Input tensor must be a PyTorch tensor or NumPy array,'
438
                f' but type "{type(tensor)}" was found'
439
            )
440
            raise TypeError(message)
441
        ndim = tensor.ndim
442
        if ndim != 4:
443
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
444
        if tensor.dtype == torch.bool:
445
            tensor = tensor.to(torch.uint8)
446
        if self.check_nans and torch.isnan(tensor).any():
447
            warnings.warn(f'NaNs found in tensor', RuntimeWarning)
448
        return tensor
449
450
    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the result of passing ``tensor`` through ``ensure_4d``."""
        reshaped = ensure_4d(tensor)
        return reshaped
452
453
    @staticmethod
454
    def _parse_affine(affine: TypeData) -> np.ndarray:
455
        if affine is None:
456
            return np.eye(4)
457
        if isinstance(affine, torch.Tensor):
458
            affine = affine.numpy()
459
        if not isinstance(affine, np.ndarray):
460
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
461
        if affine.shape != (4, 4):
462
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
463
        return affine.astype(np.float64)
464
465
    def load(self) -> None:
        r"""Load the image from disk.

        Nothing is done if the image is already loaded. Otherwise, every
        path is read, the tensors are concatenated along the first (channels)
        dimension and the result is stored with :meth:`set_data`. The affine
        matrix of the first file is the one that is stored. After loading,
        the data is a 4D tensor of size :math:`(C, W, H, D)` and the affine
        is a 2D :math:`4 \times 4` matrix to convert voxel indices to world
        coordinates.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.

        Warns:
            RuntimeWarning: If the affine matrix of a file differs from the
                affine matrix of the first file.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        first_path = paths[0]
        tensor, affine = self.read_and_check(first_path)
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {first_path}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                # Fail here with a clear message instead of letting the
                # concatenation below fail
                raise RuntimeError(message)
            tensors.append(new_tensor)
        self.set_data(torch.cat(tensors))
        self.affine = affine
        self._loaded = True
500
501
    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
502
        tensor, affine = self.reader(path)
503
        tensor = self.parse_tensor_shape(tensor)
504
        tensor = self._parse_tensor(tensor)
505
        affine = self._parse_affine(affine)
506
        if self.channels_last:
507
            tensor = tensor.permute(3, 0, 1, 2)
508
        if self.check_nans and torch.isnan(tensor).any():
509
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
510
        return tensor, affine
511
512
    def save(self, path: TypePath, squeeze: Optional[bool] = None) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: Whether to remove singleton dimensions before saving.
                If ``None``, the array will be squeezed if the output format is
                JP(E)G, PNG, BMP or TIF(F).
        """
        tensor, affine = self.data, self.affine
        write_image(tensor, affine, path, squeeze=squeeze)
527
528
    def is_2d(self) -> bool:
529
        return self.shape[-1] == 1
530
531
    def numpy(self) -> np.ndarray:
532
        """Get a NumPy array containing the image data."""
533
        return np.asarray(self)
534
535
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`.

        Args:
            **kwargs: Keyword arguments forwarded to ``nib_to_sitk``.
        """
        tensor, affine = self.data, self.affine
        return nib_to_sitk(tensor, affine, **kwargs)
538
539
    @classmethod
    def from_sitk(cls, sitk_image):
        """Instantiate a new TorchIO image from a :class:`sitk.Image`.

        The tensor and affine returned by ``sitk_to_nib`` are passed to the
        class constructor.

        Example:
            >>> import torchio as tio
            >>> import SimpleITK as sitk
            >>> sitk_image = sitk.Image(20, 30, 40, sitk.sitkUInt16)
            >>> tio.LabelMap.from_sitk(sitk_image)
            LabelMap(shape: (1, 20, 30, 40); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.IntTensor; memory: 93.8 KiB)
            >>> sitk_image = sitk.Image((224, 224), sitk.sitkVectorFloat32, 3)
            >>> tio.ScalarImage.from_sitk(sitk_image)
            ScalarImage(shape: (3, 224, 224, 1); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 588.0 KiB)
        """
        converted = sitk_to_nib(sitk_image)
        tensor, affine = converted
        return cls(tensor=tensor, affine=affine)
555
556
    def as_pil(self, transpose=True):
557
        """Get the image as an instance of :class:`PIL.Image`.
558
559
        .. note:: Values will be clamped to 0-255 and cast to uint8.
560
        .. note:: To use this method, `Pillow` needs to be installed:
561
            `pip install Pillow`.
562
        """
563
        try:
564
            from PIL import Image as ImagePIL
565
        except ModuleNotFoundError as e:
566
            message = (
567
                'Please install Pillow to use Image.as_pil():'
568
                ' pip install Pillow'
569
            )
570
            raise RuntimeError(message) from e
571
572
        self.check_is_2d()
573
        tensor = self.data
574
        if len(tensor) == 1:
575
            tensor = torch.cat(3 * [tensor])
576
        if len(tensor) != 3:
577
            raise RuntimeError('The image must have 1 or 3 channels')
578
        if transpose:
579
            tensor = tensor.permute(3, 2, 1, 0)
580
        else:
581
            tensor = tensor.permute(3, 1, 2, 0)
582
        array = tensor.clamp(0, 255).numpy()[0]
583
        return ImagePIL.fromarray(array.astype(np.uint8))
584
585
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        last_index = np.array(self.spatial_shape) - 1
        center_index = last_index / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        # Going from RAS+ to LPS+ flips the sign of the two first coordinates
        return (-r, -a, s) if lps else (r, a, s)
600
601
    def set_check_nans(self, check_nans: bool) -> None:
602
        self.check_nans = check_nans
603
604
    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are shown through :meth:`as_pil`. Other images are passed
        to ``plot_volume`` together with ``kwargs``.
        """
        if not self.is_2d():
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
            return
        self.as_pil().show()
611
612
613
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The image type is always :attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine; any other explicit type is an error
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
640
641
642
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    The image type is always :attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> # Loading from files: the paths must be passed as one sequence
        >>> tissues = 'gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'
        >>> tpm = tio.LabelMap(tissues)

    Intensity transforms are not applied to these images.

    Nearest neighbor interpolation is always used to resample label maps,
    independently of the specified interpolation type in the transform
    instantiation.

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine; any other explicit type is an error
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
669