Passed
Pull Request — master (#380)
by Fernando
created 01:10

torchio.data.image.Image.__init__()   (grade: C)

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      9
eloc    34
nop     8
dl      0
loc     42
rs      6.6666
c       0
b       0
f       0

How to fix

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more, or different, data.

There are several approaches to avoid long parameter lists: replace groups of parameters that belong together with a single parameter object, pass a whole object instead of unpacking its attributes into separate arguments, or compute a value inside the method instead of receiving it as a parameter. A sketch of the parameter-object approach follows.
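As a minimal sketch of the parameter-object idea, assuming nothing about TorchIO's actual API: the optional flags of a constructor like the one flagged below can be grouped into one options object. The names ImageOptions and MedicalImage are hypothetical, chosen only for illustration.

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ImageOptions:
        # Hypothetical bundle of flags that travel together; adding a new
        # flag later touches this class instead of every call site
        type: Optional[str] = None
        check_nans: bool = False
        channels_last: bool = False

    class MedicalImage(dict):
        def __init__(
                self,
                path: Optional[str] = None,
                tensor: Optional[Any] = None,
                affine: Optional[Any] = None,
                options: Optional[ImageOptions] = None,
                **kwargs: Any,
                ):
            # Fall back to default options when none are given
            self.options = options if options is not None else ImageOptions()
            super().__init__(**kwargs)

    # Call sites name only the flags they care about
    image = MedicalImage('t1.nii.gz', options=ImageOptions(check_nans=True))

With this grouping the constructor drops from eight parameters (nop 8 above) to six, and related flags can no longer drift out of sync across call sites. The flagged source follows.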

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List

import torch
import humanize
import numpy as np
from PIL import Image as ImagePIL
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the tensor read from
            disk will be interpreted as the channels dimension and moved to
            the first position.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: t1.nii.gz; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead', DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            self.data = tensor
            self.affine = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                # reported only when loaded, so that repr() stays lazy and
                # does not trigger a read from disk
                f'dtype: {self.data.type()}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string
    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    def data(self, tensor: TypeData):
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -3 + index
            return index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis: str) -> str:
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)

    def get_spacing_string(self) -> str:
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z
    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]
    def _parse_tensor(
            self,
            tensor: TypeData,
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        if tensor is None:
            if none_ok:
                return None
            else:
                raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            try:
                tensor = torch.from_numpy(tensor)
            except TypeError:
                # PyTorch does not support unsigned 16- and 32-bit integers,
                # so cast to the next signed type that can hold the values
                if tensor.dtype == np.uint16:
                    tensor = torch.from_numpy(tensor.astype(np.int32))
                elif tensor.dtype == np.uint32:
                    tensor = torch.from_numpy(tensor.astype(np.int64))
                else:
                    raise
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor)
    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine
    def load(self) -> None:
        r"""Load the image from disk.

        The 4D tensor of size :math:`(C, W, H, D)` and the :math:`4 \times 4`
        affine matrix to convert voxel indices to world coordinates are stored
        in the image; nothing is returned.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if not tensor.shape[1:] == new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.data = tensor
        self.affine = affine
        self._loaded = True
    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            # move the channels dimension from last to first position
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self.affine, **kwargs)

    def as_pil(self) -> ImagePIL:
        """Get the image as an instance of :class:`PIL.Image`."""
        self.check_is_2d()
        return ImagePIL.open(self.path)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)


class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)