Passed
Push — master ( 2c67ba...a6348a )
by Fernando
01:11
created

torchio.data.image.Image.flip_axis()   C

Complexity

Conditions 9

Size

Total Lines 15
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 14
nop 1
dl 0
loc 15
rs 6.6666
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
from PIL import Image as ImagePIL
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from ..utils import get_stem
13
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
14
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
15
from .io import (
16
    ensure_4d,
17
    read_image,
18
    write_image,
19
    nib_to_sitk,
20
    check_uint_to_int,
21
    get_rotation_and_spacing_from_affine,
22
)
23
24
25
# Keys that Image manages itself; user-supplied kwargs may not use them.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
# Pair of world coordinates delimiting the image along a single axis.
TypeBound = Tuple[float, float]
# One TypeBound for each of the three spatial axes.
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]
28
29
30
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a :class:`DeprecationWarning` is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the data read from
            :attr:`path` is treated as the channel dimension and moved to the
            front. It is not applied to a :attr:`tensor` passed directly.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or LabelMap instead', DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Validate both inputs up front so invalid values fail early, even
        # when the data will eventually be read from ``path``
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Both values are already parsed: store them directly instead of
            # going through the ``data``/``affine`` setters, which would parse
            # them again and emit the NaN warning a second time
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # Avoid touching the data (and thus loading it) if not loaded yet
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            properties = [f'path: "{self.path}"']
        return f'{self.__class__.__name__}({"; ".join(properties)})'

    def __getitem__(self, item):
        # Lazy loading: data and affine are read on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        """Return a shallow copy of the image.

        The tensor and affine objects are shared with this image, extra
        (non-protected) items are copied by reference, and the
        ``check_nans`` and ``channels_last`` settings are carried over.
        """
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value
        new_image = self.__class__(**kwargs)
        # Set after construction so the shared tensor is not NaN-checked again
        new_image.check_nans = self.check_nans
        new_image.channels_last = self.channels_last
        return new_image

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :attr:`Image.tensor`."""
        return self[DATA]

    @data.setter
    def data(self, tensor: TypeData):
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :attr:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes that the tensor takes in the RAM."""
        # Use the real element size instead of assuming 4-byte float32 voxels,
        # so integer and half-precision tensors are reported correctly
        return int(np.prod(self.shape)) * self.data.element_size()

    @property
    def bounds(self) -> np.ndarray:
        """World positions of the centers of the first voxel (index 0, 0, 0)
        and the last voxel (index ``spatial_shape - 1``), as a 2 x 3 array."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), so it is valid both
            for the full tensor and for its spatial dimensions only.

        Raises:
            ValueError: If ``axis`` is not a non-empty string or its first
                letter is not understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            # The orientation code may name the opposite direction
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    # Opposite direction of every axis letter accepted by flip_axis
    _FLIPPED_AXES = {
        'R': 'L', 'L': 'R',
        'A': 'P', 'P': 'A',
        'I': 'S', 'S': 'I',
        'T': 'B', 'B': 'T',
    }

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the letter of the direction opposite to ``axis``.

        Args:
            axis: One of ``'L'``, ``'R'``, ``'P'``, ``'A'``, ``'I'``, ``'S'``,
                ``'T'`` or ``'B'`` (single upper-case letter).

        Raises:
            ValueError: If ``axis`` is not one of the accepted letters.
        """
        try:
            return Image._FLIPPED_AXES[axis]
        except (KeyError, TypeError):
            # TypeError: unhashable inputs get the same error as unknown ones
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message) from None

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image.

        The values are computed from the outer corners of the first voxel
        (index -0.5) and of the last voxel (index ``spatial_shape - 0.5``).

        .. note:: NOTE(review): each pair is (first corner, last corner), not
            sorted, so for an axis with a negative direction in the affine the
            first value can exceed the second; with rotated affines these two
            corners may not give the true extent. Confirm callers expect this.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Expand ``~`` in ``path`` and check that it points to an existing
        file or directory (a directory may contain DICOM files).

        Raises:
            TypeError: If ``path`` cannot be converted to a :class:`Path`.
            RuntimeError: If the conversion fails with a RuntimeError.
            FileNotFoundError: If nothing exists at ``path``.
        """
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath]]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse a single path or a sequence of paths (``None`` passes through)."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate ``tensor`` and convert it to a 4D :class:`torch.Tensor`.

        NumPy arrays are converted (after ``check_uint_to_int``) and boolean
        tensors are cast to ``uint8``.

        Raises:
            RuntimeError: If ``tensor`` is ``None`` and ``none_ok`` is False.
            TypeError: If ``tensor`` is not a tensor or NumPy array.
            ValueError: If ``tensor`` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            else:
                raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.from_numpy(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` as a 4D tensor using ``ensure_4d``."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Return ``affine`` after validation, or the identity if ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` is not :math:`4 \\times 4`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        Reads every file in :attr:`path`, concatenates the tensors along the
        channel dimension and stores the result in :attr:`data`. The affine
        matrix of the first file is stored in :attr:`affine`; a warning is
        issued if another file has a different one. Does nothing if the image
        is already loaded.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.data = tensor
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file and return its 4D tensor and affine matrix.

        The channel dimension is moved to the front if ``channels_last`` is
        set, and a warning is issued if ``check_nans`` is set and NaNs exist.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    def as_pil(self) -> ImagePIL:
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        Raises:
            RuntimeError: If the image is not 2D or does not have 1 or 3
                channels.
        """
        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            # Replicate a single channel to get an RGB image
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        # (3, W, H, 1) -> (W, H, 3)
        # NOTE(review): PIL reads the first array axis as rows, so W becomes
        # the image height here — confirm this is the intended orientation
        tensor = tensor.permute(3, 1, 2, 0)[0]
        array = tensor.clamp(0, 255).numpy()
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable NaN checks for data parsed from now on."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image: 2D images are shown with PIL, others with
        ``plot_volume`` (which receives ``kwargs``)."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
500
501
502
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The image type is always :attr:`torchio.INTENSITY`. Pass any other
    argument of :class:`~torchio.Image` by keyword after the path, as
    ``type`` is always filled in by this class.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # Only INTENSITY is accepted as an explicit keyword type
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
529
530
531
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images. The image type is
    always :attr:`torchio.LABEL`. To load several files, pass them as one
    sequence (they are concatenated along the channel dimension); extra
    positional arguments would be interpreted as other parameters of
    :class:`~torchio.Image`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # Only LABEL is accepted as an explicit keyword type
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
554