Passed
Push — master ( e921a7...82a6cd )
by Fernando
01:10
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 34
nop 8
dl 0
loc 42
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

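There are several approaches to avoid long parameter lists: replacing a parameter with a method call, introducing a parameter object that groups related values, or passing a whole object instead of several values extracted from it.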

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
from PIL import Image as ImagePIL
11
from deprecated import deprecated
12
13
from ..utils import get_stem
14
from ..typing import TypeData, TypePath, TypeTripletInt, TypeTripletFloat
15
from ..constants import DATA, TYPE, AFFINE, PATH, STEM, INTENSITY, LABEL
16
from .io import (
17
    ensure_4d,
18
    read_image,
19
    write_image,
20
    nib_to_sitk,
21
    check_uint_to_int,
22
    get_rotation_and_spacing_from_affine,
23
)
24
25
26
# Keys managed by Image itself; Image.__init__ rejects them in **kwargs.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)

# Minimum and maximum world coordinate along one axis, and along all three.
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]

# Reason passed to the @deprecated decorator of the Image.data setter.
deprecation_message = (
    'Setting the image data with the property setter is deprecated.'
    ' Use the set_data() method instead'
)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :attr:`tensor` is given, the data in
            :attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a :class:`DeprecationWarning` is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
            :class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the 4D tensor read
            from :attr:`path` is moved to the first position, e.g. to convert
            data stored with channels last into the :math:`(C, W, H, D)`
            layout. It has no effect when :attr:`tensor` is given.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio as tio
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz")
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; dtype: torch.FloatTensor)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Any,
            ):
        # These must be set first: _parse_tensor() and read_and_check() read them
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or tio.LabelMap instead', DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')

        # Validate user-provided items before storing anything
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        self._loaded = False
        # The affine is validated even if it will be ignored (path given)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # set_data() parses and validates the tensor, so no need to do it
            # here as well
            self.set_data(tensor)
            self.affine = affine
            self._loaded = True

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # Only show properties that need the data if it is already in memory,
        # so that printing an image never triggers a read from disk
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
                f'dtype: {self.data.type()}',
            ]
        else:
            properties = [f'path: "{self.path}"']
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        # Data and affine are loaded lazily the first time they are requested
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self.data.numpy()

    def __copy__(self):
        """Return a new image of the same class sharing data and items.

        Accessing :attr:`data` loads the image if needed. Custom items are
        shared with the original (shallow copy).
        """
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            channels_last=self.channels_last,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # values are not copied (shallow copy)
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.tensor`."""
        return self[DATA]

    @data.setter
    @deprecated(version='0.18.16', reason=deprecation_message)
    def data(self, tensor: TypeData):
        self.set_data(tensor)

    def set_data(self, tensor: TypeData):
        """Store a 4D tensor in the :attr:`data` key and attribute.

        Args:
            tensor: 4D tensor with dimensions :math:`(C, W, H, D)`.

        Raises:
            RuntimeError: If :attr:`tensor` is ``None``.
            TypeError: If :attr:`tensor` is not a tensor or NumPy array.
            ValueError: If :attr:`tensor` is not 4D.
        """
        self[DATA] = self._parse_tensor(tensor, none_ok=False)

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @affine.setter
    def affine(self, matrix):
        self[AFFINE] = self._parse_affine(matrix)

    @property
    def type(self) -> str:
        """Image type, such as :attr:`torchio.INTENSITY`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes that the tensor data takes in RAM."""
        # Use the real element size so that e.g. uint8 label maps or float64
        # images are not reported as if they were float32
        data = self.data
        return data.numel() * data.element_size()

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), so that it can be
            used both on spatial-only and on channel-first shapes.

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string or cannot
                be matched to the image orientation.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the letter of the opposite direction of :attr:`axis`.

        Raises:
            ValueError: If :attr:`axis` is not one of ``LRPAISTB``.
        """
        if axis == 'R': flipped_axis = 'L'
        elif axis == 'L': flipped_axis = 'R'
        elif axis == 'A': flipped_axis = 'P'
        elif axis == 'P': flipped_axis = 'A'
        elif axis == 'I': flipped_axis = 'S'
        elif axis == 'S': flipped_axis = 'I'
        elif axis == 'T': flipped_axis = 'B'
        elif axis == 'B': flipped_axis = 'T'
        else:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return flipped_axis

    def get_spacing_string(self) -> str:
        """Return the spacing formatted with two decimals, e.g. ``(1.00, ...)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        # Voxel indices refer to voxel centers, so the outer edges of the
        # image are half a voxel beyond the first and last indices
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert to :class:`Path` (expanding ``~``) and check it exists.

        Raises:
            TypeError: If :attr:`path` cannot be converted to a path.
            RuntimeError: If the conversion fails with a runtime error.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath]]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse one path, or each path of a sequence, or return ``None``."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: TypeData,
            none_ok: bool = True,
            ) -> Optional[torch.Tensor]:
        """Convert :attr:`tensor` to a 4D :class:`torch.Tensor`.

        NumPy arrays are converted with :func:`torch.from_numpy` (after
        ``check_uint_to_int``) and boolean tensors are cast to ``uint8``.

        Args:
            tensor: Tensor, NumPy array or ``None``.
            none_ok: If ``True``, ``None`` is returned for a ``None`` input;
                otherwise a :class:`RuntimeError` is raised.

        Raises:
            RuntimeError: If :attr:`tensor` is ``None`` and
                :attr:`none_ok` is ``False``.
            TypeError: If :attr:`tensor` is not a tensor or NumPy array.
            ValueError: If :attr:`tensor` is not 4D.
        """
        if tensor is None:
            if none_ok:
                return None
            raise RuntimeError('Input tensor cannot be None')
        if isinstance(tensor, np.ndarray):
            tensor = check_uint_to_int(tensor)
            tensor = torch.from_numpy(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return :attr:`tensor` passed through ``ensure_4d``."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: TypeData) -> np.ndarray:
        """Return :attr:`affine` as a :math:`4 \\times 4` NumPy array.

        ``None`` becomes the identity matrix and tensors are converted with
        ``.numpy()``.

        Raises:
            TypeError: If :attr:`affine` is not a NumPy array or tensor.
            ValueError: If the shape is not ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        Data is read from :attr:`path`. If several paths were given, their
        tensors are concatenated along the first (channels) dimension and
        the affine of the first file is used; a :class:`RuntimeWarning` is
        issued if the affines differ. The result is stored with
        :meth:`set_data`. Nothing is done if the image is already loaded.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                # Previously the exception was created but never raised
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self.set_data(tensor)
        self.affine = affine
        self._loaded = True

    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read one file and return its 4D tensor and affine.

        The tensor is passed through :meth:`parse_tensor_shape`, its last
        dimension is moved first if :attr:`channels_last` is ``True``, and a
        :class:`RuntimeWarning` is issued if :attr:`check_nans` is ``True``
        and NaNs are found.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Save image to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        write_image(
            self.data,
            self.affine,
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`."""
        return nib_to_sitk(self.data, self.affine, **kwargs)

    def as_pil(self) -> ImagePIL:
        """Get the image as an instance of :class:`PIL.Image`.

        .. note:: Values will be clamped to 0-255 and cast to uint8.

        Raises:
            RuntimeError: If the image is not 2D or does not have 1 or 3
                channels.
        """
        self.check_is_2d()
        tensor = self.data
        if len(tensor) == 1:
            # Replicate the single channel to build an RGB image
            tensor = torch.cat(3 * [tensor])
        if len(tensor) != 3:
            raise RuntimeError('The image must have 1 or 3 channels')
        # (C, W, H, 1) -> (1, W, H, C) -> (W, H, C)
        tensor = tensor.permute(3, 1, 2, 0)[0]
        array = tensor.clamp(0, 255).numpy()
        return ImagePIL.fromarray(array.astype(np.uint8))

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool) -> None:
        """Enable or disable the NaN check for subsequent tensor parsing."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are shown as a PIL image; otherwise ``plot_volume`` is
        called with :attr:`kwargs`.
        """
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
517
518
519
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The image type is always :attr:`torchio.INTENSITY`; passing any other
    ``type`` keyword argument raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' is treated as if INTENSITY had been requested
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
546
547
548
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images. The image type is
    always :attr:`torchio.LABEL`; passing any other ``type`` keyword argument
    raises a :class:`ValueError`.

    Several files can be loaded into one multi-channel label map by passing
    a sequence of paths as the first argument; they are concatenated along
    the channels dimension.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
571