Passed
Push — master ( 47015e...349d77 )
by Fernando
01:20
created

torchio.data.image.Image.save()   A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 4
dl 0
loc 16
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys managed by Image itself; users may not pass them as kwargs
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`). If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match. If the read tensor is not 4D,
            TorchIO will try to guess the dimensions meanings.
            If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimensions,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, H, W, D)`. It is converted to ``float32``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. If not given, it defaults to ``True``
            if :attr:`path` is a file and ``False`` if it is a directory
            (e.g. DICOM). It must not be given together with :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, if
            both :attr:`tensor` and :attr:`channels_last` are given, if
            :attr:`tensor` is not 4D, or if a reserved key is passed in
            ``kwargs``.
        FileNotFoundError: If a path is neither a file nor a directory.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        if tensor is not None and channels_last is not None:
            message = (
                'channels_last cannot be specified if tensor is not None.'
                ' The tensor shape must be (C, H, W, D)'
            )
            raise ValueError(message)

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            # In-memory data: nothing will ever be read from disk
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        # The number of channels is typically stored in the last dimension on
        # disk, but if a tensor is given, the channels must be in the first one
        if channels_last is None:
            channels_last = self.guess_channels_last(self.path)
        self.channels_last = channels_last
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid loading the data just to build the representation
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        # Lazy loading: data and affine are read on first access
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        """Return a new image of the same class sharing this image's data."""
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
        )
        kwargs.update(self._get_custom_items())
        return self.__class__(**kwargs)

    def _get_custom_items(self) -> Dict[str, Any]:
        """Return the dictionary items that are not reserved by TorchIO."""
        # Values are not copied; new instances share them with this one
        return {
            key: value
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        }

    @property
    def data(self):
        """Image data, loaded from disk on first access if needed."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias for :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix, loaded from disk on first access if needed."""
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape without the channels (first) dimension."""
        return self.shape[1:]

    def check_is_2d(self):
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Size of the first spatial dimension of a 2D image."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def width(self) -> int:
        """Size of the second spatial dimension of a 2D image."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def orientation(self):
        """Axis codes of the affine, as computed by ``nib.aff2axcodes``."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine matrix."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        """Size of the data in bytes, assuming 4 bytes (float32) per voxel."""
        return np.prod(self.shape) * 4

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case versions
            and first letters are also valid, as only the first letter will be
            used. ``'Height'`` and ``'Width'`` are also accepted.

        Returns:
            ``1`` for height, ``2`` for width, and otherwise a negative index
            into the 4D tensor, so that it is valid whether or not the channels
            dimension is taken into account.

        Raises:
            ValueError: If ``axis`` is not a string or is not understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, H, W, D)
        if axis == 'H':
            return 1
        elif axis == 'W':
            return 2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            return -4 + index

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the anatomical axis code opposite to ``axis``.

        For example, ``'L'`` is returned for ``'R'``.

        Raises:
            ValueError: If ``axis`` is not one of ``'L'``, ``'R'``, ``'A'``,
                ``'P'``, ``'I'`` or ``'S'``.
        """
        opposites = {'R': 'L', 'L': 'R', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}
        if isinstance(axis, str) and axis in opposites:
            return opposites[axis]
        message = (
            f'Axis not understood. Please use a value in {tuple("LRAPIS")}'
        )
        raise ValueError(message)

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, ...)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    def get_bounds(self):
        """Get image bounds in mm.

        The bounds are computed at the outer edges of the first and last
        voxels, i.e. half a voxel beyond their centers.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an expanded :class:`~pathlib.Path`.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse a single path or each path of a sequence; ``None`` is kept."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a float32 :class:`torch.Tensor`.

        Returns ``None`` if ``tensor`` is ``None``.

        Raises:
            ValueError: If the tensor is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(
            self,
            tensor: torch.Tensor,
            num_spatial_dims: int,
            channels_last: Optional[bool],
            ) -> torch.Tensor:
        """Delegate to ``ensure_4d`` to obtain a 4D tensor."""
        return ensure_4d(tensor, channels_last, num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning the identity if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` is not :math:`4 \\times 4`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _read_and_parse(self, path: Path) -> Tuple[torch.Tensor, Any]:
        """Read one file, make its tensor 4D and warn if it contains NaNs."""
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(
            tensor,
            self.num_spatial_dims,
            self.channels_last,
        )
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def load(self) -> None:
        r"""Load the image from disk.

        Each file in :attr:`path` is read and converted to a 4D tensor of
        size :math:`(C, H, W, D)`, and the tensors are concatenated along the
        channel dimension. The result is stored under the ``DATA`` key, and
        the :math:`4 \times 4` affine matrix of the first file, which converts
        voxel indices to world coordinates, under the ``AFFINE`` key.
        Nothing is done if the image has already been loaded.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return

        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self._read_and_parse(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self._read_and_parse(path)

            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)

            # Channels can only be concatenated if spatial shapes agree
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)

            tensors.append(new_tensor)

        self[DATA] = torch.cat(tensors)
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def guess_channels_last(
            self,
            path: Optional[Union[Path, List[Path]]],
            ) -> Optional[bool]:
        """Guess whether channels are stored last in the file(s) at ``path``.

        Returns ``None`` if ``path`` is ``None``, ``False`` for a directory
        (e.g. DICOM) and ``True`` otherwise.
        """
        # https://github.com/fepegar/torchio/issues/273
        if path is None:
            return None
        if isinstance(path, list):  # assume all images are similar
            path = path[0]
        return not path.is_dir()

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image containing the voxels in ``[ini, fin)``.

        The affine is updated so that the patch keeps its position in world
        coordinates. Custom dictionary items are passed to the new image.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
        )
        kwargs.update(self._get_custom_items())
        return self.__class__(**kwargs)
515
516
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        # Always pass the type explicitly so Image does not warn about it
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
546
547
548
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        # Always pass the type explicitly so Image does not warn about it
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
572