Passed
Push — master ( e59e2c...7d4e6f )
by Fernando
01:21
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 44
Code Lines 35

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 35
dl 0
loc 44
rs 6.6666
c 0
b 0
f 0
cc 9
nop 9

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as soon as you need more or different data.

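There are several approaches to avoid long parameter lists, for example introducing a parameter object, passing the whole object instead of several of its fields, or replacing a parameter with a method call.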

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List, Callable
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
from deprecated import deprecated
11
12
from ..utils import get_stem
13
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
14
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
15
from .io import (
16
    ensure_4d,
17
    read_image,
18
    write_image,
19
    nib_to_sitk,
20
    check_uint_to_int,
21
    get_rotation_and_spacing_from_affine,
22
)
23
24
25
# Dictionary keys managed by Image itself; they cannot be passed as **kwargs
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# (min, max) along one axis, and the three per-axis pairs
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Reason passed to the @deprecated decorator on the Image.data setter
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
33
34
35
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the 4D tensor obtained from the reader is
            permuted with ``permute(3, 0, 1, 2)``, i.e. its last dimension is
            moved to the channels position.
        reader: Callable object that takes a path and returns a 4D tensor and a
            2D, :math:`4 \times 4` affine matrix. This can be used if your data
            is saved in a custom format, such as ``.npy`` (see example below).
            If the affine matrix is ``None``, an identity matrix will be used.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, or if
            one of the reserved keys is passed in :attr:`kwargs`.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> import numpy as np
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> image.save('doubled_image.nii.gz')
        >>> numpy_reader = lambda path: (np.load(path), np.eye(4))
        >>> image = tio.ScalarImage('t1.npy', reader=numpy_reader)

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            reader: Callable = read_image,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last
        self.reader = reader

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead',
                stacklevel=2,  # point at the code that built the image
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')

        # Reject reserved keys before the dictionary is modified
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        self._loaded = False

        # The affine is validated even if it ends up unused (path-only images)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # set_data() parses and validates the tensor, so it is not parsed
            # separately beforehand
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        """Summarize the image; the data is not loaded to build the string."""
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            properties = [f'path: "{self.path}"']
        properties = '; '.join(properties)
        return f'{self.__class__.__name__}({properties})'

    def __getitem__(self, item):
        """Get a dictionary item, loading the image first if needed."""
        # Data and affine are only available after reading from disk
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    def __array__(self):
        """Return the image data as a NumPy array."""
        return self.data.numpy()

    def __copy__(self):
        """Create a new instance of the same class from the current contents.

        Accessing :attr:`data` loads the image if it was not loaded yet.
        Values of the extra dictionary items are passed as they are, i.e.
        they are not copied.
        """
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.

        Raises:
            RuntimeError: If :attr:`tensor` is ``None``.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type stored in the dictionary."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def itemsize(self):
        """Element size of the data type."""
        return self.data.element_size()

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * self.itemsize

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    @property
    def num_channels(self) -> int:
        """Get the number of channels in the associated 4D tensor."""
        return len(self.data)

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            Negative index of the corresponding tensor dimension.

        Raises:
            ValueError: If :attr:`axis` is not a string, is empty, or its
                first letter is not a known axis name.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            # Indexing an empty string below would raise a bare IndexError
            raise ValueError('Axis must not be an empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2

        try:
            index = self.orientation.index(axis)
        except ValueError:
            # The image may be oriented towards the opposite direction
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the letter of the direction opposite to :attr:`axis`.

        Args:
            axis: One of ``'L'``, ``'R'``, ``'P'``, ``'A'``, ``'I'``, ``'S'``,
                ``'T'`` or ``'B'``.

        Raises:
            ValueError: If :attr:`axis` is not one of the letters above.
        """
        opposites = {
            'R': 'L',
            'L': 'R',
            'A': 'P',
            'P': 'A',
            'I': 'S',
            'S': 'I',
            'T': 'B',
            'B': 'T',
        }
        try:
            return opposites[axis]
        except (KeyError, TypeError):  # unknown letter or unhashable input
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message) from None

    def get_spacing_string(self) -> str:
        """Format the spacing with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        # Half a voxel is added on each side: indices refer to voxel centers
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert the input to a :class:`pathlib.Path` that must exist.

        Raises:
            TypeError: If the input cannot be interpreted as a path.
            RuntimeError: If the user directory cannot be expanded.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath]]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse one path or a sequence of paths; ``None`` is passed through."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: TypeData,
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate the input data and convert it to a :class:`torch.Tensor`.

        Boolean tensors are converted to ``torch.uint8``. If
        :attr:`check_nans` is set, a warning is issued when NaNs are found.

        Args:
            tensor: 4D PyTorch tensor or NumPy array, or ``None``.
            none_ok: If ``True``, ``None`` is returned when :attr:`tensor` is
                ``None``. Otherwise, an error is raised.

        Raises:
            RuntimeError: If :attr:`tensor` is ``None`` and :attr:`none_ok`
                is ``False``.
            TypeError: If :attr:`tensor` is not a PyTorch tensor or NumPy array.
            ValueError: If :attr:`tensor` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            else:
                raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.as_tensor(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        ndim = tensor.ndim
        if ndim != 4:
            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pass the tensor through ``ensure_4d``."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: TypeData) -> np.ndarray:
        r"""Validate the affine matrix and return it as a NumPy array.

        Args:
            affine: :math:`4 \times 4` NumPy array or PyTorch tensor. If
                ``None``, an identity matrix is returned.

        Raises:
            TypeError: If the input is not a NumPy array or a PyTorch tensor.
            ValueError: If the shape is not :math:`(4, 4)`.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        """Load the image from disk.

        Nothing is done if the image has already been loaded. If multiple
        paths were given, all files are read and concatenated along the
        first (channels) dimension, and the affine of the first file is used.
        A warning is issued if the affine matrices of the files differ.

        The tensor and the affine are stored in the image; nothing is
        returned.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                # Fail here with a clear message rather than in torch.cat
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read a file with the reader and validate tensor and affine.

        Args:
            path: Path passed to the reader.

        Returns:
            Tuple containing the parsed 4D tensor and the parsed affine matrix.
        """
        tensor, affine = self.reader(path)
        tensor = self.parse_tensor_shape(tensor)
        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Check whether the size of the last dimension is 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    def as_pil(self, transpose=True):
        """Get the image as an instance of :class:`PIL.Image`.

        Args:
            transpose: If ``True``, the two in-plane spatial dimensions are
                swapped when building the output array.

        Raises:
            RuntimeError: If Pillow is not installed, if the image is not 2D
                or if it does not have 1 or 3 channels.

        .. note:: Values will be clamped to 0-255 and cast to uint8.
        .. note:: To use this method, `Pillow` needs to be installed:
            `pip install Pillow`.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = (
                'Please install Pillow to use Image.as_pil():'
                ' pip install Pillow'
            )
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            # Replicate the single channel to get three channels
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        if transpose:
            tensor = tensor.permute(3, 2, 1, 0)
        else:
            tensor = tensor.permute(3, 1, 2, 0)
        array = tensor.clamp(0, 255).numpy()[0]
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Set whether tensors are checked for NaNs when they are parsed."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are shown using :meth:`as_pil`. Otherwise, the keyword
        arguments are passed to ``plot_volume``.
        """
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
552
553
554
class ScalarImage(Image):
    """Image whose voxel values are scalar intensities.

    The image type is always :attr:`torchio.INTENSITY`. Passing a different
    ``type`` keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine; an explicit, different one is an error
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
581
582
583
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images.

    The image type is always :attr:`torchio.LABEL`. Passing a different
    ``type`` keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap(                     # loading from files
        ...     [
        ...         'gray_matter.nii.gz',
        ...         'white_matter.nii.gz',
        ...         'csf.nii.gz',
        ...     ],
        ... )

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # The type is fixed for this class; only a matching value is accepted
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
606