Passed
Pull Request — master (#332)
by Fernando
03:27 queued 02:14
created

torchio.data.image.Image.plot()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys managed by Image itself. They are rejected when given as
# extra keyword arguments at instantiation and skipped when the remaining
# items are forwarded to a new image (copy / crop).
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`. It is converted to ``float32``. Other input
            types raise a :class:`TypeError`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of each 4D tensor read
            from :attr:`path` is moved to the first (channels) position when
            the image is loaded. It is not applied to :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        # These must be set first: parse_tensor() below reads check_nans.
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            # In-memory data: no lazy loading needed
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self, dtype=None, copy=None):
        """Support :func:`numpy.asarray` / :func:`numpy.array` on images.

        ``dtype`` and ``copy`` are the arguments NumPy may pass to this
        protocol method; they are optional for backward compatibility.
        """
        array = self[DATA].numpy()
        if dtype is not None:
            array = array.astype(dtype, copy=False)
        if copy:
            array = array.copy()
        return array

    def __copy__(self):
        return self._new_with_data(self.data, self.affine)

    def _new_with_data(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            ) -> 'Image':
        """Create an instance of this class sharing metadata with ``self``.

        The type, path and every non-protected dictionary item of this image
        are forwarded to the new instance. Values are passed by reference,
        not copied.
        """
        kwargs = dict(
            tensor=tensor,
            affine=affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        return self[DATA]

    @property
    def tensor(self) -> torch.Tensor:
        return self.data

    @property
    def affine(self) -> np.ndarray:
        return self[AFFINE]

    @property
    def type(self) -> str:
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    def check_is_2d(self) -> None:
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    @property
    def bounds(self) -> np.ndarray:
        """World coordinates of the centers of the first and last voxels."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case versions
            and first letters are also valid, as only the first letter will be
            used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``) so that it is valid
            whether or not the channels dimension is counted.

        Raises:
            ValueError: If ``axis`` is not a non-empty string or is not
                understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the anatomical axis code opposite to ``axis``.

        Raises:
            ValueError: If ``axis`` is not one of ``'RLAPIS'``.
        """
        opposites = {
            'R': 'L',
            'L': 'R',
            'A': 'P',
            'P': 'A',
            'I': 'S',
            'S': 'I',
        }
        # The isinstance guard keeps unhashable inputs on the ValueError path
        flipped = opposites.get(axis) if isinstance(axis, str) else None
        if flipped is None:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return flipped

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm.

        The bounds are computed at the outer faces of the first and last
        voxels (voxel indices ``-0.5`` and ``shape - 0.5``).
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an existing file or directory :class:`Path`."""
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Union[Path, List[Path], None]:
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D ``float32`` :class:`torch.Tensor`.

        Raises:
            TypeError: If ``tensor`` is neither a NumPy array nor a tensor.
            ValueError: If ``tensor`` is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        else:
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``; ``None`` becomes a 4x4 identity matrix."""
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        Every file in :attr:`path` is read and checked. The files are
        concatenated on the channel dimension, and the result is cached in
        the image dictionary together with the affine of the first file.
        Nothing happens if the image is already loaded.

        Raises:
            RuntimeError: If the files have different spatial shapes.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        """Read one file and return its 4D tensor and affine.

        The tensor is passed through :meth:`parse_tensor_shape`, permuted if
        :attr:`channels_last` is set, and checked for NaNs if
        :attr:`check_nans` is set.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot the image; keyword arguments are forwarded to ``plot_image``."""
        from ..visualization import plot_image  # avoid circular import
        plot_image(self, **kwargs)

    def crop(
            self,
            index_ini: TypeTripletInt,
            index_fin: TypeTripletInt,
            ) -> 'Image':
        """Return a new image with the voxels in a spatial index range.

        Args:
            index_ini: First voxel index (inclusive) along each spatial axis.
            index_fin: Last voxel index (exclusive) along each spatial axis.

        The new affine has its translation set to the world position of
        ``index_ini``, so the cropped data stays at the same physical location.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        return self._new_with_data(patch, new_affine)
class ScalarImage(Image):
    """Image whose type is always :py:attr:`torchio.INTENSITY`.

    This is a thin wrapper around :py:class:`~torchio.Image` that fixes the
    image type, so it does not need to be passed explicitly.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' is treated as if INTENSITY had been requested
        requested = kwargs.get('type', INTENSITY)
        if requested != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Tensors must be 4D, :math:`(C, W, H, D)`. Several files must be given
    as a single sequence, because the second positional argument of
    :py:class:`~torchio.data.image.Image` is ``type``.

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)