Passed
Push — master ( 2c67ba...a6348a )
by Fernando
01:11
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 34
nop 8
dl 0
loc 42
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

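There are several approaches to avoid long parameter lists, for example: replacing a parameter with a method call, passing a whole object instead of several of its fields (Preserve Whole Object), or grouping related parameters into a single parameter object (Introduce Parameter Object).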

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
from PIL import Image as ImagePIL
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from ..utils import get_stem
13
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
14
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
15
from .io import (
16
    ensure_4d,
17
    read_image,
18
    write_image,
19
    nib_to_sitk,
20
    check_uint_to_int,
21
    get_rotation_and_spacing_from_affine,
22
)
23
24
25
# Dictionary keys managed by Image itself. They may not be passed as extra
# keyword arguments and are skipped when copying an image's custom items.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# A pair of world coordinates along a single axis.
TypeBound = Tuple[float, float]

# One bound per spatial axis (x, y, z).
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]
28
29
30
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a :class:`DeprecationWarning` is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the 4D tensors read
            from :attr:`path` is moved to the front, e.g.
            :math:`(W, H, D, C)` becomes :math:`(C, W, H, D)`. It has no
            effect on a :attr:`tensor` passed directly.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters. Keys reserved by the image itself
            (data, affine, type, path and stem) raise a :class:`ValueError`.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        # Must be set before any parsing: _parse_tensor() reads check_nans
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or LabelMap instead', DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Both values have just been validated, so store them directly.
            # Going through the ``data`` setter would parse the tensor a
            # second time and emit any NaN warning twice.
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # Only the path is shown for unloaded images, so that printing an
        # image never triggers reading it from disk
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            properties = [f'path: "{self.path}"']
        return f'{self.__class__.__name__}({"; ".join(properties)})'

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        """Return a new image of the same class sharing data and affine.

        Custom (non-protected) items and the NaN-checking setting are
        carried over. Values are not copied, i.e. the copy is shallow.
        """
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: the value is shared, not copied
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    def data(self, tensor: TypeData):
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the last spatial size is not 1."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        # NOTE: assumes 4 bytes per voxel (float32) whatever the actual dtype
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), valid both for the
            4D tensor and for its spatial shape.

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string or its first
                letter is not a known axis code.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return index - 3

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the code of the axis pointing in the opposite direction.

        Args:
            axis: One of ``'L'``, ``'R'``, ``'P'``, ``'A'``, ``'I'``,
                ``'S'``, ``'T'`` or ``'B'``.

        Raises:
            ValueError: If :attr:`axis` is not one of the codes above.
        """
        opposites = {
            'R': 'L', 'L': 'R',
            'A': 'P', 'P': 'A',
            'I': 'S', 'S': 'I',
            'T': 'B', 'B': 'T',
        }
        # isinstance check first so unhashable inputs raise ValueError too
        if not isinstance(axis, str) or axis not in opposites:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return opposites[axis]

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        # Voxel edges, not centers, hence the half-voxel offsets
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert to :class:`~pathlib.Path`, expand ``~`` and check existence.

        Raises:
            TypeError: If :attr:`path` cannot be converted to a path.
            RuntimeError: If the conversion fails with a runtime error.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Optional[Union[TypePath, Sequence[TypePath]]]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse ``None``, a single path, or a sequence of paths."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Validate a 4D tensor or NumPy array and convert it to a tensor.

        Boolean tensors are cast to ``torch.uint8``. If :attr:`check_nans` is
        ``True``, a :class:`RuntimeWarning` is issued when NaNs are present.

        Raises:
            RuntimeError: If :attr:`tensor` is ``None`` and
                :attr:`none_ok` is ``False``.
            TypeError: If :attr:`tensor` is not a tensor or NumPy array.
            ValueError: If :attr:`tensor` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.from_numpy(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return :attr:`tensor` with its shape made 4D by ``ensure_4d``."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Return the identity if ``None``, else validate a 4x4 NumPy array.

        Raises:
            TypeError: If :attr:`affine` is not a NumPy array.
            ValueError: If its shape is not ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and cache it in the image dictionary.

        Does nothing if the image is already loaded. If :attr:`path` is a
        list, the tensors of all files are concatenated along the channel
        dimension and the affine of the first file is kept. A
        :class:`RuntimeWarning` is issued if the affines differ.

        After loading, the 4D tensor of size :math:`(C, W, H, D)` and the
        :math:`4 \times 4` affine matrix (voxel indices to world
        coordinates) are available as :attr:`data` and :attr:`affine`.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                # Previously the exception was built but never raised
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.data = tensor
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file and return its 4D tensor and affine.

        Applies :attr:`channels_last` and, if :attr:`check_nans` is ``True``,
        warns when NaNs are found.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            # Move the last dimension to the front
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    def as_pil(self) -> ImagePIL:
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        Raises:
            RuntimeError: If the image is not 2D or does not have 1 or 3
                channels.
        """
        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            # Grayscale: replicate the single channel into RGB
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        # (C, W, H, 1) -> (1, W, H, C) -> (W, H, C)
        tensor = tensor.permute(3, 1, 2, 0)[0]
        array = tensor.clamp(0, 255).numpy()
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable the NaN check performed when parsing data."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
500
501
502
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The image type is always :attr:`torchio.INTENSITY`. Passing any other
    ``type`` keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' is treated as the (only valid) intensity type
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
529
530
531
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    The image type is always :attr:`torchio.LABEL`. Passing any other
    ``type`` keyword argument raises a :class:`ValueError`. Intensity
    transforms are not applied to these images.

    Several files can be loaded into a single label map by passing a list of
    paths, in which case they are concatenated along the channel dimension.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
554