Passed
Pull Request — master (#380)
by Fernando
01:25
created

torchio.data.image.Image._parse_tensor()   D

Complexity

Conditions 12

Size

Total Lines 30
Code Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 12
eloc 27
nop 3
dl 0
loc 30
rs 4.8
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like torchio.data.image.Image._parse_tensor() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
from PIL import Image as ImagePIL
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from ..utils import (
13
    nib_to_sitk,
14
    get_rotation_and_spacing_from_affine,
15
    get_stem,
16
    ensure_4d,
17
)
18
from ..torchio import (
19
    TypeData,
20
    TypePath,
21
    TypeTripletInt,
22
    TypeTripletFloat,
23
    DATA,
24
    TYPE,
25
    AFFINE,
26
    PATH,
27
    STEM,
28
    INTENSITY,
29
    LABEL,
30
)
31
from .io import read_image, write_image
32
33
34
# Dictionary keys managed by Image itself; passing any of them as an extra
# keyword argument to Image raises a ValueError.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
# Pair of (first, last) coordinates along a single axis.
TypeBound = Tuple[float, float]
# One TypeBound per spatial axis, as returned by Image.get_bounds().
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]
37
38
39
class Image(dict):
40
    r"""TorchIO image.
41
42
    For information about medical image orientation, check out `NiBabel docs`_,
43
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
44
    `SimpleITK docs`_.
45
46
    Args:
47
        path: Path to a file or sequence of paths to files that can be read by
48
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
49
            DICOM files. If :attr:`tensor` is given, the data in
50
            :attr:`path` will not be read.
51
            If a sequence of paths is given, data
52
            will be concatenated on the channel dimension so spatial
53
            dimensions must match.
54
        type: Type of image, such as :attr:`torchio.INTENSITY` or
55
            :attr:`torchio.LABEL`. This will be used by the transforms to
56
            decide whether to apply an operation, or which interpolation to use
57
            when resampling. For example, `preprocessing`_ and `augmentation`_
58
            intensity transforms will only be applied to images with type
59
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
60
            all types, and nearest neighbor interpolation is always used to
61
            resample images with type :attr:`torchio.LABEL`.
62
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
63
            :class:`~torchio.data.sampler.weighted.WeightedSampler`.
64
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 4D
65
            :class:`torch.Tensor` or NumPy array with dimensions
66
            :math:`(C, W, H, D)`.
67
        affine: If :attr:`path` is not given, :attr:`affine` must be a
68
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
69
            identity matrix.
70
        check_nans: If ``True``, issues a warning if NaNs are found
71
            in the image. If ``False``, images will not be checked for the
72
            presence of NaNs.
73
        **kwargs: Items that will be added to the image dictionary, e.g.
74
            acquisition parameters.
75
76
    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
77
    when needed.
78
79
    Example:
80
        >>> import torchio as tio
81
        >>> image = tio.ScalarImage('t1.nii.gz')  # subclass of Image
82
        >>> image  # not loaded yet
83
        ScalarImage(path: t1.nii.gz; type: intensity)
84
        >>> times_two = 2 * image.data  # data is loaded and cached here
85
        >>> image
86
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
87
        >>> image.save('doubled_image.nii.gz')
88
89
    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
90
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
91
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
92
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
93
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
94
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
95
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
96
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
97
    """
98
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: str = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        # These attributes must exist before the tensor is parsed, as
        # _parse_tensor() reads self.check_nans
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use tio.ScalarImage'
                ' or LabelMap instead',
                DeprecationWarning,
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # Data was passed explicitly, so nothing will be read from disk
            self.data = tensor
            self.affine = affine
            self._loaded = True

        # Reserved keys are filled in by this class and cannot be overridden
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
140
141
    def __repr__(self):
142
        properties = []
143
        if self._loaded:
144
            properties.extend([
145
                f'shape: {self.shape}',
146
                f'spacing: {self.get_spacing_string()}',
147
                f'orientation: {"".join(self.orientation)}+',
148
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
149
            ])
150
        else:
151
            properties.append(f'path: "{self.path}"')
152
        properties.append(f'dtype: {self.data.type()}')
153
        properties = '; '.join(properties)
154
        string = f'{self.__class__.__name__}({properties})'
155
        return string
156
157
    def __getitem__(self, item):
        """Get a dictionary item, loading the image first if necessary.

        The image is loaded only when the data or the affine is requested and
        that key is not stored in the dictionary yet.
        """
        must_load = item in (DATA, AFFINE) and item not in self
        if must_load:
            self.load()
        return super().__getitem__(item)
162
163
    def __array__(self):
        """Return the image data as a NumPy array."""
        tensor = self[DATA]
        return tensor.numpy()
165
166
    def __copy__(self):
        """Create a new instance of the same class from this image.

        The tensor, affine, type and path are passed to the constructor
        together with every non-reserved item of the dictionary. The values
        are passed as they are, i.e. no copy of them is made here.
        """
        kwargs = {
            'tensor': self.data,
            'affine': self.affine,
            'type': self.type,
            'path': self.path,
        }
        extra_items = {
            key: value
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        }
        kwargs.update(extra_items)
        return self.__class__(**kwargs)
177
178
    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :attr:`Image.tensor`."""
        tensor = self[DATA]
        return tensor

    @data.setter
    def data(self, tensor: TypeData):
        # The input is validated and converted first; ``None`` is rejected
        parsed = self._parse_tensor(tensor, none_ok=False)
        self[DATA] = parsed
186
187
    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :attr:`Image.data`."""
        data = self.data
        return data
191
192
    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        matrix = self[AFFINE]
        return matrix

    @affine.setter
    def affine(self, matrix):
        # The matrix is validated first; ``None`` becomes the identity
        parsed = self._parse_affine(matrix)
        self[AFFINE] = parsed
200
201
    @property
    def type(self) -> str:
        """Image type stored in the dictionary when the image was created."""
        image_type = self[TYPE]
        return image_type
204
205
    @property
206
    def shape(self) -> Tuple[int, int, int, int]:
207
        """Tensor shape as :math:`(C, W, H, D)`."""
208
        return tuple(self.data.shape)
209
210
    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        full_shape = self.shape
        # Drop the first (channels) dimension
        return full_shape[1:]
214
215
    def check_is_2d(self) -> None:
216
        if not self.is_2d():
217
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
218
            raise RuntimeError(message)
219
220
    @property
221
    def height(self) -> int:
222
        """Image height, if 2D."""
223
        self.check_is_2d()
224
        return self.spatial_shape[1]
225
226
    @property
227
    def width(self) -> int:
228
        """Image width, if 2D."""
229
        self.check_is_2d()
230
        return self.spatial_shape[0]
231
232
    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes computed from the affine matrix."""
        codes = nib.aff2axcodes(self.affine)
        return codes
236
237
    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        # Only the spacing is needed; the rotation is discarded
        _rotation, voxel_sizes = get_rotation_and_spacing_from_affine(
            self.affine)
        return tuple(voxel_sizes)
242
243
    @property
244
    def memory(self) -> float:
245
        """Number of Bytes that the tensor takes in the RAM."""
246
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel
247
248
    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        # Indices of the first and last voxels of the image
        corner_indices = [(0, 0, 0), np.array(self.spatial_shape) - 1]
        corner_points = tuple(
            nib.affines.apply_affine(self.affine, index)
            for index in corner_indices
        )
        return np.array(corner_points)
256
257
    def axis_name_to_index(self, axis: str) -> int:
258
        """Convert an axis name to an axis index.
259
260
        Args:
261
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
262
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
263
                versions and first letters are also valid, as only the first
264
                letter will be used.
265
266
        .. note:: If you are working with animals, you should probably use
267
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
268
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
269
            respectively.
270
271
        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
272
            ``'Left'`` and ``'Right'``.
273
        """
274
        # Top and bottom are used for the vertical 2D axis as the use of
275
        # Height vs Horizontal might be ambiguous
276
277
        if not isinstance(axis, str):
278
            raise ValueError('Axis must be a string')
279
        axis = axis[0].upper()
280
281
        # Generally, TorchIO tensors are (C, W, H, D)
282
        if axis in 'TB':  # Top, Bottom
283
            return -2
284
        else:
285
            try:
286
                index = self.orientation.index(axis)
287
            except ValueError:
288
                index = self.orientation.index(self.flip_axis(axis))
289
            # Return negative indices so that it does not matter whether we
290
            # refer to spatial dimensions or not
291
            index = -3 + index
292
            return index
293
294
    # flake8: noqa: E701
295
    @staticmethod
296
    def flip_axis(axis: str) -> str:
297
        if axis == 'R': return 'L'
298
        elif axis == 'L': return 'R'
299
        elif axis == 'A': return 'P'
300
        elif axis == 'P': return 'A'
301
        elif axis == 'I': return 'S'
302
        elif axis == 'S': return 'I'
303
        else:
304
            values = ', '.join('LRPAISTB')
305
            message = f'Axis not understood. Please use one of: {values}'
306
            raise ValueError(message)
307
308
    def get_spacing_string(self) -> str:
309
        strings = [f'{n:.2f}' for n in self.spacing]
310
        string = f'({", ".join(strings)})'
311
        return string
312
313
    def get_bounds(self) -> TypeBounds:
        """Get minimum and maximum world coordinates occupied by the image."""
        # Half a voxel is added on each side of the first and last indices
        corner_indices = (
            3 * (-0.5,),
            np.array(self.spatial_shape) - 0.5,
        )
        first_point, last_point = (
            nib.affines.apply_affine(self.affine, index)
            for index in corner_indices
        )
        # Transpose to get one (first, last) pair per axis
        per_axis = np.array((first_point, last_point)).T.tolist()
        bounds_x, bounds_y, bounds_z = per_axis
        return bounds_x, bounds_y, bounds_z
322
323
    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert the input to a :class:`pathlib.Path` and check it exists.

        Raises:
            TypeError: If the input cannot be used to build a path.
            RuntimeError: If building or expanding the path fails with a
                :class:`RuntimeError`.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            parsed = Path(path).expanduser()
        except TypeError:
            raise TypeError(
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
        except RuntimeError:
            raise RuntimeError(
                f'Conversion to path not possible for variable: {path}'
            )

        # might be a dir with DICOM
        exists = parsed.is_file() or parsed.is_dir()
        if not exists:
            raise FileNotFoundError(f'File not found: "{parsed}"')
        return parsed
344
345
    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath]]
            ) -> Union[Path, List[Path]]:
        """Parse a single path or a sequence of paths.

        Returns:
            ``None`` if the input is ``None``, a single path if the input is a
            string or a :class:`pathlib.Path`, and a list of paths otherwise.
        """
        if path is None:
            return None
        if not isinstance(path, (str, Path)):
            # Anything else is treated as an iterable of paths
            return [self._parse_single_path(item) for item in path]
        return self._parse_single_path(path)
355
356
    def _parse_tensor(
            self,
            tensor: TypeData,
            none_ok: bool = True,
            ) -> torch.Tensor:
        """Validate the input data and convert it to a 4D PyTorch tensor.

        Args:
            tensor: PyTorch tensor, NumPy array or ``None``.
            none_ok: If ``True``, ``None`` is returned when the input is
                ``None``. Otherwise, an error is raised in that case.

        Returns:
            The 4D tensor. Boolean tensors are converted to ``uint8``.

        Raises:
            RuntimeError: If the input is ``None`` and :attr:`none_ok` is
                ``False``.
            TypeError: If the input is not a tensor or an array, or if the
                array has a data type that cannot be converted.
            ValueError: If the input is not 4D.
        """
        if tensor is None:
            if not none_ok:
                raise RuntimeError('Input tensor cannot be None')
            return None

        is_array = isinstance(tensor, np.ndarray)
        if not is_array and not isinstance(tensor, torch.Tensor):
            message = 'Input tensor must be a PyTorch tensor or NumPy array'
            raise TypeError(message)

        if is_array:
            try:
                tensor = torch.from_numpy(tensor)
            except TypeError:
                # If the direct conversion is not supported, retry these
                # unsigned types using a larger signed type
                fallback_types = {
                    np.dtype(np.uint16): np.int32,
                    np.dtype(np.uint32): np.int64,
                }
                signed_type = fallback_types.get(tensor.dtype)
                if signed_type is None:
                    raise
                tensor = torch.from_numpy(tensor.astype(signed_type))

        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.uint8)
        if self.check_nans:
            if torch.isnan(tensor).any():
                warnings.warn('NaNs found in tensor', RuntimeWarning)
        return tensor
386
387
    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the result of passing the tensor through ``ensure_4d``."""
        tensor_4d = ensure_4d(tensor)
        return tensor_4d
389
390
    @staticmethod
391
    def _parse_affine(affine: np.ndarray) -> np.ndarray:
392
        if affine is None:
393
            return np.eye(4)
394
        if not isinstance(affine, np.ndarray):
395
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
396
        if affine.shape != (4, 4):
397
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
398
        return affine
399
400
    def load(self) -> None:
401
        r"""Load the image from disk.
402
403
        Returns:
404
            Tuple containing a 4D tensor of size :math:`(C, W, H, D)` and a 2D
405
            :math:`4 \times 4` affine matrix to convert voxel indices to world
406
            coordinates.
407
        """
408
        if self._loaded:
409
            return
410
        paths = self.path if isinstance(self.path, list) else [self.path]
411
        tensor, affine = self.read_and_check(paths[0])
412
        tensors = [tensor]
413
        for path in paths[1:]:
414
            new_tensor, new_affine = self.read_and_check(path)
415
            if not np.array_equal(affine, new_affine):
416
                message = (
417
                    'Files have different affine matrices.'
418
                    f'\nMatrix of {paths[0]}:'
419
                    f'\n{affine}'
420
                    f'\nMatrix of {path}:'
421
                    f'\n{new_affine}'
422
                )
423
                warnings.warn(message, RuntimeWarning)
424
            if not tensor.shape[1:] == new_tensor.shape[1:]:
425
                message = (
426
                    f'Files shape do not match, found {tensor.shape}'
427
                    f'and {new_tensor.shape}'
428
                )
429
                RuntimeError(message)
430
            tensors.append(new_tensor)
431
        tensor = torch.cat(tensors)
432
        self.data = tensor
433
        self.affine = affine
434
        self._loaded = True
435
436
    def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
        """Read a single file and prepare its data.

        The tensor is passed through :meth:`parse_tensor_shape`. If
        ``channels_last`` is set, the last dimension is moved to the first
        position. If ``check_nans`` is set, a warning is issued when the data
        contains NaNs.

        Returns:
            Tuple containing the tensor and the affine matrix.
        """
        data, affine = read_image(path)
        data = self.parse_tensor_shape(data)
        if self.channels_last:
            data = data.permute(3, 0, 1, 2)
        if self.check_nans:
            if torch.isnan(data).any():
                warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
        return data, affine
444
445
    def save(self, path: TypePath, squeeze: bool = True) -> None:
        """Write the image data and affine matrix to disk.

        Args:
            path: String or instance of :class:`pathlib.Path`.
            squeeze: If ``True``, singleton dimensions will be removed
                before saving.
        """
        data = self[DATA]
        affine = self.affine
        write_image(data, affine, path, squeeze=squeeze)
459
460
    def is_2d(self) -> bool:
461
        return self.shape[-1] == 1
462
463
    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        # The conversion is delegated to np.asarray
        array = np.asarray(self)
        return array
466
467
    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :class:`sitk.Image`.

        Args:
            **kwargs: Keyword arguments passed on to ``nib_to_sitk``.
        """
        data = self[DATA]
        affine = self.affine
        return nib_to_sitk(data, affine, **kwargs)
470
471
    def as_pil(self) -> ImagePIL:
        """Get the image as an instance of :class:`PIL.Image`.

        The image is opened from :attr:`path`, not built from the tensor.

        Raises:
            RuntimeError: If the image is not 2D.
        """
        self.check_is_2d()
        pil_image = ImagePIL.open(self.path)
        return pil_image
475
476
    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        shape = np.array(self.spatial_shape)
        # Index halfway between the first and the last voxels
        center_index = (shape - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        # For LPS+, the sign of the first two coordinates is flipped
        return (-r, -a, s) if lps else (r, a, s)
491
492
    def set_check_nans(self, check_nans: bool) -> None:
493
        self.check_nans = check_nans
494
495
    def plot(self, **kwargs) -> None:
        """Plot image.

        2D images are shown using PIL. Otherwise, the image and the keyword
        arguments are passed to ``plot_volume``.
        """
        if not self.is_2d():
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
            return
        self.as_pil().show()
502
503
504
class ScalarImage(Image):
    """Image whose pixel values represent scalars.

    The image type is always :attr:`torchio.INTENSITY`. Passing a different
    value for ``type`` raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> # Loading from a file
        >>> t1_image = tio.ScalarImage('t1.nii.gz')
        >>> dmri = tio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = tio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[tio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine, as it is filled in below
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
531
532
533
class LabelMap(Image):
    """Image whose pixel values represent categorical labels.

    Intensity transforms are not applied to these images.

    The image type is always :attr:`torchio.LABEL`. Passing a different value
    for ``type`` raises a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> labels = tio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = tio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = tio.LabelMap(                     # loading from files
        ...     [
        ...         'gray_matter.nii.gz',
        ...         'white_matter.nii.gz',
        ...         'csf.nii.gz',
        ...     ],
        ... )

    See :class:`~torchio.Image` for more information.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        # The type is always passed as a keyword argument, so several paths
        # must be given as one sequence and not as separate positional
        # arguments
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
556