Passed
Pull Request — master (#334)
by Fernando
01:13
created

torchio.data.image.Image.affine()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
from PIL import Image as ImagePIL
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from ..utils import (
13
    nib_to_sitk,
14
    get_rotation_and_spacing_from_affine,
15
    get_stem,
16
    ensure_4d,
17
)
18
from ..torchio import (
19
    TypeData,
20
    TypePath,
21
    TypeTripletInt,
22
    TypeTripletFloat,
23
    DATA,
24
    TYPE,
25
    AFFINE,
26
    PATH,
27
    STEM,
28
    INTENSITY,
29
    LABEL,
30
)
31
from .io import read_image, write_image
32
33
34
# Dictionary keys managed by Image itself; Image.__init__ rejects them when
# they are passed as extra keyword arguments.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
35
36
37
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the data read from
            :attr:`path` is moved to the first position, i.e. it is treated as
            the channels dimension. It has no effect on data passed through
            :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Data given in memory: nothing will be read from disk later
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        # dict.__init__ only adds the new items; DATA and AFFINE set above
        # are kept
        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        """Summarize the image without triggering a read from disk."""
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Shape, spacing, etc. would require loading the data
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        """Get an item, loading the image first if data or affine is missing."""
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    def __array__(self):
        """Return the image data as a NumPy array (used by ``np.asarray``)."""
        return self[DATA].numpy()

    def __copy__(self):
        """Create a new instance of the same class from this image's items.

        Accessing the data and the affine loads the image if needed. Values
        of non-protected keys are passed as they are, i.e. they are shared
        with the new instance.
        """
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.tensor`."""
        return self[DATA]

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @property
    def type(self) -> str:
        """Image type, as stored under the type key of the dictionary."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise an error if the image is not 2D.

        Raises:
            RuntimeError: If :py:meth:`is_2d` returns ``False``.
        """
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        # NOTE(review): assumes 4 bytes per element (float32); confirm that
        # data read from disk always has that dtype
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            Negative index of the tensor dimension that corresponds to the
            axis.

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string or its
                first letter is not understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            # axis[0] below would fail with an uninformative IndexError
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            # The orientation codes contain the opposite direction only
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis):
        """Get the code of the direction opposite to the given one.

        Args:
            axis: One of ``'R'``, ``'L'``, ``'A'``, ``'P'``, ``'I'`` or
                ``'S'``.

        Raises:
            ValueError: If :attr:`axis` is not one of the codes above.
        """
        opposites = {
            'R': 'L',
            'L': 'R',
            'A': 'P',
            'P': 'A',
            'I': 'S',
            'S': 'I',
        }
        try:
            return opposites[axis]
        except (KeyError, TypeError):  # unknown code or unhashable input
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message) from None

    def get_spacing_string(self):
        """Format the spacing with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm.

        Unlike :py:attr:`bounds`, which uses the centers of the voxels, the
        indices used here are shifted by half a voxel.

        Returns:
            Three lists, one per coordinate axis, each containing the
            coordinate of the first and of the last point.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert the input to a :py:class:`pathlib.Path` that exists.

        Raises:
            TypeError: If the input cannot be interpreted as a path.
            RuntimeError: If the user directory cannot be expanded.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Union[Path, List[Path], None]:
        """Parse a path or a sequence of paths; ``None`` is passed through."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert the input to a 4D float tensor; ``None`` is passed through.

        Raises:
            ValueError: If the input is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pass the tensor through ``ensure_4d`` and return the result."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        r"""Validate the affine matrix, defaulting to the identity.

        Raises:
            TypeError: If the input is not a NumPy array.
            ValueError: If the shape of the input is not :math:`4 \times 4`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The 4D tensor of size :math:`(C, W, H, D)` and the :math:`4 \times 4`
        affine matrix are stored in the image dictionary; nothing is
        returned. If the image is already loaded, this method does nothing.

        If there are multiple paths, the tensors are concatenated along the
        first dimension and the affine of the first file is kept.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                # Fail here with an informative message instead of letting
                # the concatenation below fail
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        """Read one file and return its 4D tensor and its affine matrix.

        If ``channels_last`` is set, the last dimension is moved to the first
        position. If ``check_nans`` is set, a warning is issued when the
        tensor contains NaNs.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Check whether the size of the last dimension is 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def as_pil(self):
        """Get the image as an instance of :py:class:`PIL.Image`."""
        self.check_is_2d()
        # NOTE(review): the file at self.path is opened, not the tensor in
        # memory, so this presumably cannot work for images created from a
        # tensor only and ignores changes made to the data; confirm intent
        return ImagePIL.open(self.path)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        """Set whether a warning is issued if NaNs are found when reading."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are shown using PIL and :attr:`kwargs` are ignored.
        Otherwise, :attr:`kwargs` are passed to ``plot_volume``.
        """
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
474
475
476
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If the keyword argument :py:attr:`type` is given and it
            is not :py:attr:`torchio.INTENSITY`.
    """
    def __init__(self, *args, **kwargs):
        # An explicitly passed type is tolerated only if it is the right one
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
506
507
508
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap(                     # loading from files
        ...     (
        ...         'gray_matter.nii.gz',
        ...         'white_matter.nii.gz',
        ...         'csf.nii.gz',
        ...     ),
        ... )

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If the keyword argument :py:attr:`type` is given and it
            is not :py:attr:`torchio.LABEL`.
    """
    def __init__(self, *args, **kwargs):
        # An explicitly passed type is tolerated only if it is the right one
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
532