Passed
Push — master ( e8c1dc...d14fe8 )
by Fernando
01:15
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 34
nop 8
dl 0
loc 42
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more or different data.

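There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object, or passing a whole object instead of several values taken from it.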

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
from deprecated import deprecated
11
12
from ..utils import get_stem
13
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
14
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
15
from .io import (
16
    ensure_4d,
17
    read_image,
18
    write_image,
19
    nib_to_sitk,
20
    check_uint_to_int,
21
    get_rotation_and_spacing_from_affine,
22
)
23
24
25
# Dictionary keys managed by Image itself; users may not pass them as kwargs.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# A pair of world coordinates along a single axis (see Image.get_bounds).
TypeBound = Tuple[float, float]
# One TypeBound per spatial axis.
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Reason shown by the deprecated ``Image.data`` setter.
deprecation_message = (
    'Setting the image data with the property setter is deprecated. Use the'
    ' set_data() method instead'
)
33
34
35
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a :class:`DeprecationWarning` is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`. Boolean tensors are stored as ``uint8``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array or PyTorch tensor. If ``None``,
            :attr:`affine` is an identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the data read from
            :attr:`path` is moved to the first (channel) position, i.e.
            :math:`(W, H, D, C)` becomes :math:`(C, W, H, D)`. This only
            affects data loaded from disk, not :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead', DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # In-memory data: nothing will ever be read from disk
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        # dict.__init__ updates (does not clear) the items already set above
        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # Only report data-dependent properties once loaded, so that printing
        # an image never triggers reading it from disk
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            properties = [f'path: "{self.path}"']
        return f'{self.__class__.__name__}({"; ".join(properties)})'

    def __getitem__(self, item):
        # Data and affine are loaded lazily on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        """Return a new image sharing the data, affine and extra items."""
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are shared with the original
        new_image = self.__class__(**kwargs)
        # These are plain attributes, not dict items, so they are not part of
        # kwargs; set them afterwards so subclass signatures don't matter
        new_image.check_nans = self.check_nans
        new_image.channels_last = self.channels_last
        return new_image

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def itemsize(self):
        """Element size of the data type."""
        return self.data.element_size()

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * self.itemsize

    @property
    def bounds(self) -> np.ndarray:
        """World positions of the centers of the first and last voxels.

        Returns a :math:`2 \\times 3` array: the first row is the center of
        voxel ``(0, 0, 0)`` and the second row the center of the last voxel.
        """
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``) so that it is valid
            both for the spatial shape and for the 4D tensor.

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string, or its
                first letter is not a known axis code.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            # E.g. 'L' requested on an 'R'-oriented axis: same axis, flipped
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return index - 3

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the code of the opposite direction along the same axis.

        For example, ``'L'`` becomes ``'R'`` and ``'T'`` becomes ``'B'``.

        Raises:
            ValueError: If :attr:`axis` is not one of ``L, R, P, A, I, S, T,
                B``.
        """
        flipped_axes = {
            'R': 'L', 'L': 'R',
            'A': 'P', 'P': 'A',
            'I': 'S', 'S': 'I',
            'T': 'B', 'B': 'T',
        }
        # isinstance check keeps unhashable inputs on the ValueError path
        if isinstance(axis, str) and axis in flipped_axes:
            return flipped_axes[axis]
        values = ', '.join('LRPAISTB')
        message = f'Axis not understood. Please use one of: {values}'
        raise ValueError(message)

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, …)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get the world coordinates of the outer edges of the image.

        The outer corners of the first voxel (index ``-0.5``) and of the last
        voxel (index ``shape - 0.5``) are mapped to world space. One
        ``(first, last)`` pair is returned per world axis. The pairs are not
        sorted, so ``first`` is larger than ``last`` along flipped axes.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert to an expanded :class:`Path` that must exist on disk."""
        try:
            path = Path(path).expanduser()
        except TypeError as e:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message) from e
        except RuntimeError as e:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from e

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse a single path or a sequence of paths; ``None`` passes through."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate and convert input data to a 4D :class:`torch.Tensor`.

        NumPy arrays are converted (after :func:`check_uint_to_int`) and
        boolean tensors are cast to ``uint8``.

        Raises:
            RuntimeError: If :attr:`tensor` is ``None`` and not
                :attr:`none_ok`.
            TypeError: If :attr:`tensor` is not a tensor or NumPy array.
            ValueError: If :attr:`tensor` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            else:
                raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.from_numpy(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Delegate to :func:`ensure_4d` (presumably pads missing dims)."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        """Return the affine as a 4x4 NumPy array (identity if ``None``)."""
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The data read from :attr:`path` is stored with :meth:`set_data`. If
        several paths were given, their data are concatenated along the
        channel dimension. The affine of the first file is stored as
        :attr:`affine`. Nothing is done if the image is already loaded.

        A :class:`RuntimeWarning` is issued if the files have different affine
        matrices.

        Raises:
            RuntimeError: If the files have different spatial shapes.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file and return its 4D tensor and affine.

        If :attr:`channels_last`, the last dimension is moved first. If
        :attr:`check_nans`, a :class:`RuntimeWarning` is issued for NaNs.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    def as_pil(self):
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.
        .. note:: To use this method, `Pillow` needs to be installed:
            `pip install Pillow`.

        Raises:
            RuntimeError: If Pillow is not installed, the image is not 2D,
                or it does not have 1 or 3 channels.
        """
        try:
            from PIL import Image as ImagePIL
        except ModuleNotFoundError as e:
            message = (
                'Please install Pillow to use Image.as_pil():'
                ' pip install Pillow'
            )
            raise RuntimeError(message) from e

        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            # Grayscale: replicate into three identical channels
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        # (C, W, H, 1) -> (1, W, H, C) -> (W, H, C)
        tensor = tensor.permute(3, 1, 2, 0)[0]
        array = tensor.clamp(0, 255).numpy()
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable the NaN check for subsequently parsed data."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
532
533
534
class ScalarImage(Image):
    """Image whose voxel values are scalars, such as intensities.

    The type is always :attr:`torchio.INTENSITY`; passing any other ``type``
    keyword raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' defaults to INTENSITY, so only a conflicting one fails
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
561
562
563
class LabelMap(Image):
    """Image whose voxel values represent categorical labels.

    The type is always :attr:`torchio.LABEL`; passing any other ``type``
    keyword raises a :class:`ValueError`. Intensity transforms are not
    applied to these images. Boolean tensors are stored as ``uint8``.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])  # a sequence of paths is concatenated along the channels

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
586