Passed
Push — master ( a5fd0f...582603 )
by Fernando
01:13
created

torchio.data.image.Image.get_center()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 8
nop 2
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ..utils import nib_to_sitk
10
from ..torchio import (
11
    TypePath,
12
    TypeTripletInt,
13
    TypeTripletFloat,
14
    DATA,
15
    TYPE,
16
    AFFINE,
17
    PATH,
18
    STEM,
19
    INTENSITY,
20
)
21
from .io import read_image
22
23
24
class Image(dict):
    r"""Class to store information about an image.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 3D
            :py:class:`torch.Tensor` with spatial dimensions
            :math:`(D, H, W)`. A channels dimension is prepended, so the
            stored tensor has shape :math:`(1, D, H, W)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Raises:
        ValueError: If neither ``path`` nor ``tensor`` is given, if ``path``
            is given together with ``tensor`` or ``affine``, or if ``kwargs``
            uses one of the reserved keys.
        TypeError: If ``tensor`` is not a :py:class:`torch.Tensor`, if
            ``affine`` is not a NumPy array, or if ``path`` cannot be
            converted to a :py:class:`~pathlib.Path`.
        RuntimeError: If ``tensor`` is not 3D.
        FileNotFoundError: If ``path`` is neither a file nor a directory.
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[torch.Tensor] = None,
            affine: Optional[np.ndarray] = None,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        if path is not None and (tensor is not None or affine is not None):
            message = 'If a path is given, tensor and affine must be None'
            raise ValueError(message)
        self.tensor = self.parse_tensor(tensor)
        # parse_affine already substitutes the identity matrix for None,
        # so no further fallback is needed here
        self.affine = self.parse_affine(affine)
        # These keys are used by the rest of the library to store the loaded
        # image; letting callers set them would clobber that information
        for key in (DATA, AFFINE, TYPE, PATH, STEM):
            if key in kwargs:
                raise ValueError(f'Key {key} is reserved. Use a different one')

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self.type = type
        self.is_sample = False  # set to True by ImagesDataset

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the tensor stored under the ``DATA`` key.

        Raises :class:`KeyError` if no data has been stored under ``DATA``.
        """
        return self[DATA].shape

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the stored data without its first (channels) dimension."""
        return self.shape[1:]

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing, user-expanded :class:`Path`.

        Returns ``None`` when ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a :class:`Path`.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    @staticmethod
    def parse_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Validate a 3D tensor and prepend a channels dimension.

        Returns ``None`` when ``tensor`` is ``None``; otherwise a tensor of
        shape :math:`(1, D, H, W)`.

        Raises:
            TypeError: If ``tensor`` is not a :py:class:`torch.Tensor`.
            RuntimeError: If ``tensor`` does not have exactly 3 dimensions.
        """
        if tensor is None:
            return None
        if not isinstance(tensor, torch.Tensor):
            message = (
                f'Input tensor must be a PyTorch tensor, not {type(tensor)}'
            )
            raise TypeError(message)
        num_dimensions = tensor.dim()
        if num_dimensions != 3:
            message = (
                'The input tensor must have 3 dimensions (D, H, W),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        tensor = tensor.unsqueeze(0)  # add channels dimension
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning a 4x4 identity matrix for ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self, check_nans: bool = True) -> Tuple[torch.Tensor, np.ndarray]:
        r"""Load the image from disk.

        The file is expected to be monomodal/grayscale and 2D or 3D.
        A channels dimension is added to the tensor. If the image was
        created from a tensor rather than a path, the parsed tensor and
        affine given at construction are returned unchanged.

        Args:
            check_nans: If ``True``, issues a warning if NaNs are found
                in the image

        Returns:
            Tuple containing a 4D data tensor of size
            :math:`(1, D_{in}, H_{in}, W_{in})`
            and a 2D 4x4 affine matrix
        """
        if self.path is None:
            return self.tensor, self.affine
        tensor, affine = read_image(self.path)
        # Prepend singleton axes until the tensor is 3D (e.g. 2D -> 1 x H x W)
        # https://github.com/pytorch/pytorch/issues/9410#issuecomment-404968513
        tensor = tensor[(None,) * (3 - tensor.ndim)]  # force to be 3D
        tensor = tensor.unsqueeze(0)  # add channels dimension
        if check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        return tensor, affine

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Return the data stored under ``DATA`` as a NumPy array."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Convert the stored data and affine to a SimpleITK image.

        Only the first channel is converted, because the 4x4 affine
        describes a 3D spatial volume. The affine stored under ``AFFINE``
        (the one loaded with the data) is preferred; the affine given at
        construction is used if that key is absent.
        """
        # NOTE(review): assumes nib_to_sitk expects a 3D (D, H, W) volume
        # matching a 4x4 affine — confirm against ..utils.nib_to_sitk
        affine = self.get(AFFINE, self.affine)
        return nib_to_sitk(self[DATA][0], affine)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS (default) or LPS coordinates.

        Args:
            lps: If ``True``, return the point as computed by SimpleITK;
                otherwise negate its first two coordinates.

        Returns:
            The physical coordinates of the continuous center voxel index
            :math:`(size - 1) / 2`.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        # Continuous (sub-voxel) index of the center; convert to plain floats
        # so SimpleITK receives a native sequence rather than a NumPy array
        center_index = ((size - 1) / 2).tolist()
        l, p, s = image.TransformContinuousIndexToPhysicalPoint(center_index)
        if lps:
            return (l, p, s)
        # Negating the first two axes converts LPS to RAS
        return (-l, -p, s)
161