Passed
Push — master (0bdb47...b0bac6)
by Fernando
created 01:14

torchio.data.image.Image.__init__() — grade: C

Complexity
    Conditions: 9

Size
    Total Lines: 42
    Code Lines: 34

Duplication
    Lines: 0
    Ratio: 0%

Importance
    Changes: 0

Metric    Value
cc        9
eloc      34
nop       8
dl        0
loc       42
rs        6.6666
c         0
b         0
f         0

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as the method evolves to need more, or different, data.

There are several approaches to avoid long parameter lists, for example:

- Introduce Parameter Object: group parameters that naturally belong together into a single object.
- Preserve Whole Object: pass an existing object and let the method read the values it needs, instead of unpacking it at the call site.
- Replace Parameter with Method Call: let the method compute or look up a value itself when it can, instead of receiving it.
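As an illustration, here is a minimal sketch of the Introduce Parameter Object approach applied to a constructor with boolean flags like the one reported below; the LoadOptions dataclass and make_image function are hypothetical, not part of TorchIO:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LoadOptions:
        # Hypothetical parameter object bundling the flags that rarely change
        check_nans: bool = False
        channels_last: bool = False

    def make_image(path=None, tensor=None, affine=None,
                   options: Optional[LoadOptions] = None):
        # Three data arguments plus one options object instead of six
        # parameters; adding a new flag later changes LoadOptions,
        # not every call site.
        options = options or LoadOptions()
        ...

Grouping the flags this way keeps the signature stable as new options are added.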

import warnings
from pathlib import Path
from typing import Any, Tuple, Optional, Union, Sequence, List

import torch
import humanize
import numpy as np
from PIL import Image as ImagePIL
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
    check_uint_to_int,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the read tensor is
            assumed to contain the channels and will be permuted to be first.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: t1.nii.gz; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # NaNs are removed by ITK by default
            channels_last: bool = False,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use'
                ' tio.ScalarImage or tio.LabelMap instead',
                DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            self.data = tensor
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ])
        else:
            # Do not access self.data here, so that repr() stays lazy
            properties.append(f'path: "{self.path}"')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string
    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    def data(self, tensor: TypeData):
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        """Number of bytes the tensor takes up in RAM, assuming 32-bit floats."""
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel
    @property
    def bounds(self) -> np.ndarray:
        """Positions of the centers of the first and last voxels, in world coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and Bottom are used for the vertical 2D axis, as "Height" vs
        # "Horizontal" might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -3 + index
            return index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis: str) -> str:
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)

    def get_spacing_string(self) -> str:
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z
    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'type {type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        if tensor is None:
            if none_ok:
                return None
            else:
                raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.from_numpy(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor
    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine
    def load(self) -> None:
        r"""Load the image from disk.

        The data is read into a 4D tensor of size :math:`(C, W, H, D)` and a
        :math:`4 \times 4` affine matrix to convert voxel indices to world
        coordinates, both of which are stored in the image dictionary.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match: found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)  # concatenate on the channel dimension
        self.data = tensor
        self.affine = affine
        self._loaded = True
    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            # Move channels from last to first: (W, H, D, C) -> (C, W, H, D)
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine
    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self.affine, **kwargs)

    def as_pil(self) -> ImagePIL:
        """Get the image as an instance of :class:`PIL.Image`."""
        self.check_is_2d()
        return ImagePIL.open(self.path)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap(                     # loading from files
        ...     ['gray_matter.nii.gz', 'white_matter.nii.gz', 'csf.nii.gz'],
        ... )

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
550