Passed
Pull Request — master (#332)
by Fernando
01:14
created

torchio.data.image.Image.as_sitk()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Keys managed by Image itself. They cannot be passed as extra keyword
# arguments to Image, and they are skipped when extra items are forwarded to
# a new Image (copy / crop).
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`. It is converted to ``float32``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of each 4D tensor read
            from :attr:`path` is moved to the first (channels) position. It is
            not applied to data given through :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters. Keys in ``PROTECTED_KEYS`` are not
            allowed.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: str = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Data given in memory: nothing will be read from disk later
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a (possibly slow) load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self, dtype=None, copy=None):
        """Return the image data as a NumPy array (NumPy array protocol).

        Args:
            dtype: Optional data type requested by NumPy, e.g. through
                ``np.asarray(image, dtype=...)``.
            copy: If ``True``, a copy of the data is always returned.
        """
        array = self[DATA].numpy()
        if dtype is not None:
            array = array.astype(dtype, copy=False)
        if copy:
            array = array.copy()
        return array

    def _get_extra_items(self) -> Dict[str, Any]:
        """Return the items of this image whose keys are not protected."""
        return {
            key: value
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        }

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        # Extra values are shared with this image, not copied
        kwargs.update(self._get_extra_items())
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.tensor`."""
        return self[DATA]

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @property
    def type(self) -> str:
        """Image type, e.g. :attr:`torchio.INTENSITY`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes that the tensor takes in the RAM.

        Uses the actual element size of the tensor, so it is also correct for
        data types other than ``float32``.
        """
        data = self.data
        return data.element_size() * data.numel()

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``) into the tensor
            dimensions, so it is valid with or without the channels dimension.

        Raises:
            ValueError: If ``axis`` is not a non-empty string or is not
                understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    @staticmethod
    def flip_axis(axis):
        """Return the orientation code opposite to ``axis``, e.g. R -> L.

        Raises:
            ValueError: If ``axis`` is not one of ``'LRPAIS'``.
        """
        opposites = {
            'R': 'L',
            'L': 'R',
            'A': 'P',
            'P': 'A',
            'I': 'S',
            'S': 'I',
        }
        if isinstance(axis, str) and axis in opposites:
            return opposites[axis]
        values = ', '.join('LRPAISTB')
        message = f'Axis not understood. Please use one of: {values}'
        raise ValueError(message)

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. (1.00, 1.00, 1.00)."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        # Voxel edges lie half a voxel away from the voxel centers
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an expanded :class:`Path` that must exist."""
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Union[Path, List[Path], None]:
        """Parse a single path or a sequence of paths (``None`` is kept)."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D ``float32`` :class:`torch.Tensor`.

        Raises:
            TypeError: If ``tensor`` is not a tensor or a NumPy array.
            ValueError: If ``tensor`` is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        else:
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pass ``tensor`` through ``ensure_4d``."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning the identity if it is ``None``."""
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and cache its data and affine.

        Does nothing if the image is already loaded. If :attr:`path` is a
        list, the tensors read from all files are concatenated along the
        channels dimension and the affine of the first file is kept. A
        :class:`RuntimeWarning` is issued if the affine matrices differ.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape} '
                    f'and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        """Read one file and return its tensor and affine.

        The tensor is passed through :meth:`parse_tensor_shape`, permuted if
        :attr:`channels_last` is set, and checked for NaNs if requested.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        """Enable or disable the NaN check."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot central slices of the image."""
        from ..visualization import plot_volume  # avoid circular import
        plot_volume(self, **kwargs)

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels from ``index_ini`` (inclusive)
        to ``index_fin`` (exclusive) along each spatial dimension.

        The affine origin is moved to ``index_ini`` so that the patch keeps
        its position in world coordinates. The data is cloned; extra items
        are shared with this image.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
        )
        kwargs.update(self._get_extra_items())
        return self.__class__(**kwargs)
483
484
485
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given, either as a keyword or as
            the second positional argument.
    """
    def __init__(self, *args, **kwargs):
        # ``type`` is the second positional parameter of Image.__init__
        type_is_positional = len(args) > 1
        if type_is_positional:
            given_type = args[1]
        else:
            given_type = kwargs.get('type', INTENSITY)
        if given_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        if not type_is_positional:
            kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
515
516
517
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given, either as a keyword or as the second positional
            argument.
    """
    def __init__(self, *args, **kwargs):
        # ``type`` is the second positional parameter of Image.__init__
        type_is_positional = len(args) > 1
        if type_is_positional:
            given_type = args[1]
        else:
            given_type = kwargs.get('type', LABEL)
        if given_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        if not type_is_positional:
            kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
541