Passed
Push — master ( e59e2c...7d4e6f )
by Fernando
01:21
created

torchio.data.image.Image.num_channels()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List, Callable
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
from deprecated import deprecated
11
12
from ..utils import get_stem
13
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
14
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
15
from .io import (
16
    ensure_4d,
17
    read_image,
18
    write_image,
19
    nib_to_sitk,
20
    check_uint_to_int,
21
    get_rotation_and_spacing_from_affine,
22
)
23
24
25
# Dictionary keys managed by Image itself; user kwargs may not reuse them.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# (minimum, maximum) world coordinate along a single axis.
TypeBound = Tuple[float, float]
# One TypeBound per spatial axis.
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Reason passed to the @deprecated decorator on the Image.data setter.
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
33
34
35
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the 4D tensors read from :attr:`path`
            are permuted so that their last dimension becomes the channels
            dimension. Tensors passed through :attr:`tensor` are not
            permuted.
        reader: Callable object that takes a path and returns a 4D tensor and a
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
            is saved in a custom format, such as ``.npy`` (see example below).
            If the affine matrix is ``None``, an identity matrix will be used.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> import numpy as np
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> image.save('doubled_image.nii.gz')
        >>> numpy_reader = lambda path: (np.load(path), np.eye(4))
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            reader: Callable = read_image,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead',
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Data given in memory: nothing needs to be read from disk later
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # Avoid triggering a (possibly slow) load just to print the image
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            properties = [f'path: "{self.path}"']
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        """Return a dictionary item, loading from disk first if needed.

        Requesting :attr:`DATA` or :attr:`AFFINE` before they have been
        stored triggers :meth:`load`.
        """
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        """Create a new image of the same class from the loaded data.

        Non-protected dictionary items are passed as keyword arguments to
        the new instance (shallow, i.e. the values are not copied).
        """
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type, e.g. :attr:`torchio.INTENSITY`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def itemsize(self) -> int:
        """Element size of the data type."""
        return self.data.element_size()

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * self.itemsize

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        return len(self.data)

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``) into the tensor
            dimensions, so it is valid whether or not the channels
            dimension is included.

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string, or its
                first letter does not name a known direction.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            # Otherwise axis[0] would raise an unhelpful IndexError
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the letter of the opposite direction along the same axis.

        Raises:
            ValueError: If :attr:`axis` is not one of ``'LRPAISTB'``.
        """
        if axis == 'R': flipped_axis = 'L'
        elif axis == 'L': flipped_axis = 'R'
        elif axis == 'A': flipped_axis = 'P'
        elif axis == 'P': flipped_axis = 'A'
        elif axis == 'I': flipped_axis = 'S'
        elif axis == 'S': flipped_axis = 'I'
        elif axis == 'T': flipped_axis = 'B'
        elif axis == 'B': flipped_axis = 'T'
        else:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return flipped_axis

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        # Voxel centers are at integer indices, so the outer faces of the
        # volume are half a voxel away from the first and last centers
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert :attr:`path` to an expanded :class:`Path` that exists.

        Raises:
            TypeError: If :attr:`path` cannot be converted to a path.
            RuntimeError: If the conversion fails with a runtime error.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse a single path or a sequence of paths; ``None`` passes through."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate :attr:`tensor` and convert it to a 4D PyTorch tensor.

        NumPy arrays are converted with :func:`check_uint_to_int` and
        :func:`torch.as_tensor`; boolean tensors are cast to ``uint8``.

        Args:
            tensor: Input data, or ``None``.
            none_ok: If ``True``, ``None`` is returned for a ``None`` input;
                otherwise a :class:`RuntimeError` is raised.

        Raises:
            RuntimeError: If :attr:`tensor` is ``None`` and :attr:`none_ok`
                is ``False``.
            TypeError: If :attr:`tensor` is neither a tensor nor an array.
            ValueError: If :attr:`tensor` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.as_tensor(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        ndim = tensor.ndim
        if ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Make :attr:`tensor` 4D using :func:`ensure_4d`."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        """Validate an affine matrix; ``None`` becomes the identity.

        Raises:
            TypeError: If :attr:`affine` is not a NumPy array or tensor.
            ValueError: If its shape is not :math:`4 \\times 4`.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        Every path in :attr:`path` is read with :meth:`read_and_check` and
        the tensors are concatenated on the channel dimension. The affine
        matrix of the first file is kept; a :class:`RuntimeWarning` is issued
        for any file whose affine differs. The result is stored with
        :meth:`set_data` and the :attr:`affine` setter. Calling this method
        on an image that is already loaded does nothing.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                # Previously the exception was created but never raised
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file with :attr:`reader` and validate its contents.

        Returns:
            The parsed 4D tensor, permuted to channels-first if
            :attr:`channels_last` is ``True``, and the parsed affine matrix.
        """
        tensor, affine = self.reader(path)
        tensor = self.parse_tensor_shape(tensor)
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        Single-channel images are repeated to 3 channels.

        Args:
            transpose: If ``True``, the two spatial axes are swapped so that
                the first spatial dimension becomes the PIL image width.

        Raises:
            RuntimeError: If Pillow is not installed, the image is not 2D or
                it does not have 1 or 3 channels.

        .. note:: Values will be clamped to 0-255 and cast to uint8.
        .. note:: To use this method, `Pillow` needs to be installed:
            `pip install Pillow`.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = (
                'Please install Pillow to use Image.as_pil():'
                ' pip install Pillow'
            )
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        if transpose:
            tensor = tensor.permute(3, 2, 1, 0)
        else:
            tensor = tensor.permute(3, 1, 2, 0)
        # Index 0 drops the singleton depth dimension, leaving (rows, cols, C)
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable the NaN warnings issued when parsing data."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are shown with :meth:`as_pil`; otherwise the keyword
        arguments are forwarded to ``torchio.visualization.plot_volume``.
        """
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
552
553
554
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # The type is fixed for this subclass. An explicit keyword ``type``
        # is accepted only if it already equals INTENSITY.
        # NOTE(review): passing ``type`` positionally would collide with the
        # keyword set below and raise a TypeError.
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)
581
582
583
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Multiple files must be given as a single sequence. Extra positional
    arguments would be bound to ``type`` and ``tensor`` instead.

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # The type is fixed for this subclass. An explicit keyword ``type``
        # is accepted only if it already equals LABEL.
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
606