Passed
Push — master ( e921a7...82a6cd )
by Fernando
01:10
created

torchio.data.image.Image.set_data()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
from PIL import Image as ImagePIL
11
from deprecated import deprecated
12
13
from ..utils import get_stem
14
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
15
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
16
from .io import (
17
    ensure_4d,
18
    read_image,
19
    write_image,
20
    nib_to_sitk,
21
    check_uint_to_int,
22
    get_rotation_and_spacing_from_affine,
23
)
24
25
26
# Keys that Image manages itself; they cannot be passed as extra kwargs
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# A pair of world coordinates along one axis, and one such pair per axis
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Message emitted by the deprecated ``Image.data`` property setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a :class:`DeprecationWarning` is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the data read from
            :attr:`path` is treated as the channels dimension and moved to the
            front. It is not applied to :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> doubled = tio.ScalarImage(tensor=times_two, affine=image.affine)
        >>> doubled.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead', DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        # An in-memory tensor makes the image "loaded" right away; otherwise
        # data and affine are read lazily from path by load()
        if tensor is not None:
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            # Do not trigger a load just to print the image
            properties = [f'path: "{self.path}"']
        return f'{self.__class__.__name__}({"; ".join(properties)})'

    def __getitem__(self, item):
        # Lazy loading: data and affine are read on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        """Return a shallow copy sharing the tensor and any extra items."""
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            channels_last=self.channels_last,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are shared with the original
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type stored under the ``type`` key."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the last spatial size is not 1."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        # Use the element size of the tensor's actual dtype rather than
        # assuming 4 bytes (float32) per voxel
        return np.prod(self.shape) * self.data.element_size()

    @property
    def bounds(self) -> np.ndarray:
        """World positions of the centers of the first and last voxels.

        Returns a :math:`2 \\times 3` array: row 0 is the center of voxel
        ``(0, 0, 0)``, row 1 the center of the voxel with the largest index.
        """
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), so it is valid for
            both the 4D tensor and its spatial shape.

        Raises:
            ValueError: If ``axis`` is not a non-empty string or its first
                letter is not a known axis code.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            # The axis may point the opposite way, e.g. 'L' in an RAS image
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the axis code pointing in the opposite direction.

        Args:
            axis: One of ``'L'``, ``'R'``, ``'P'``, ``'A'``, ``'I'``, ``'S'``,
                ``'T'`` or ``'B'``.

        Raises:
            ValueError: If ``axis`` is not one of the codes above.
        """
        opposites = {
            'R': 'L', 'L': 'R',
            'A': 'P', 'P': 'A',
            'I': 'S', 'S': 'I',
            'T': 'B', 'B': 'T',
        }
        # isinstance first so unhashable inputs also get the ValueError
        if not isinstance(axis, str) or axis not in opposites:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return opposites[axis]

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, ...)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image.

        Each returned pair holds, for one world axis, the coordinate of the
        outer edge of the first voxel and that of the last voxel, in that
        order.

        .. note:: The pairs are not sorted, so if the affine flips (or
            rotates) an axis the first value may be larger than the second.
        """
        # Voxel indices +-0.5 are the outer edges of the corner voxels
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an expanded :class:`Path` that must exist.

        Raises:
            TypeError: If ``path`` cannot be converted to a :class:`Path`.
            FileNotFoundError: If it is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Optional[Union[TypePath, Sequence[TypePath]]]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse one path, or each path of a sequence, or pass ``None``."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate and convert ``tensor`` to a 4D :class:`torch.Tensor`.

        NumPy arrays go through ``check_uint_to_int`` and are converted with
        :func:`torch.from_numpy`; boolean tensors become ``uint8``.

        Raises:
            RuntimeError: If ``tensor`` is ``None`` and ``none_ok`` is False.
            TypeError: If it is neither a tensor nor a NumPy array.
            ValueError: If it is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.from_numpy(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` reshaped to 4D by ``ensure_4d``."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        """Return ``affine`` as a 4x4 NumPy array (identity if ``None``).

        Raises:
            TypeError: If it is not a NumPy array or tensor.
            ValueError: If its shape is not ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The tensor read from :attr:`path` is stored with :meth:`set_data` and
        the affine of the first file is stored in :attr:`affine`. If several
        paths were given, their tensors are concatenated along the channel
        dimension; a :class:`RuntimeWarning` is issued if their affines differ.
        Nothing is done if the image is already loaded.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file as a 4D tensor and its affine.

        If :attr:`channels_last` is set, the last dimension is moved to the
        front. Warns if :attr:`check_nans` is set and NaNs are found.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    def as_pil(self, transpose: bool = True) -> ImagePIL:
        """Get the image as an instance of :class:`PIL.Image`.

        Args:
            transpose: If ``True``, the array passed to PIL has shape
                :math:`(H, W, C)`, so the PIL image has :attr:`width` columns
                and :attr:`height` rows. If ``False``, the array has shape
                :math:`(W, H, C)` (the behavior of earlier versions), which
                shows the picture transposed.

        Raises:
            RuntimeError: If the image is not 2D or does not have 1 or 3
                channels.

        .. note:: Values will be clamped to 0-255 and cast to uint8.
        """
        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            tensor = torch.cat(3 * [tensor])  # grayscale -> RGB
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        # Data is (C, W, H, 1); PIL expects (rows, columns, channels)
        if transpose:
            tensor = tensor.permute(3, 2, 1, 0)[0]  # (H, W, C)
        else:
            tensor = tensor.permute(3, 1, 2, 0)[0]  # (W, H, C)
        array = tensor.clamp(0, 255).numpy()
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable NaN checks for subsequent parsing and loading."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are shown with PIL; other images are passed, together with
        ``kwargs``, to ``plot_volume``.
        """
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
517
518
519
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The image type is always :attr:`torchio.INTENSITY`; passing any other
    ``type`` keyword raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing keyword counts as INTENSITY, so only explicit
        # conflicting values are rejected
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
546
547
548
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images. The image type is
    always :attr:`torchio.LABEL`; passing any other ``type`` keyword raises a
    :class:`ValueError`.

    Several files can be loaded into one label map by passing a sequence of
    paths; their data is concatenated along the channel dimension, so their
    spatial shapes must match.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # Only an explicit, conflicting ``type`` keyword is rejected
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
571