Passed
Push — master ( 03205f...9fbe6e )
by Fernando
01:15
created

torchio.data.image.Image.memory()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine
12
from ..torchio import (
13
    TypePath,
14
    TypeTripletInt,
15
    TypeTripletFloat,
16
    DATA,
17
    TYPE,
18
    AFFINE,
19
    PATH,
20
    STEM,
21
    INTENSITY,
22
)
23
from .io import read_image
24
25
26
class Image(dict):
    r"""Class to store information about an image.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example,
            `preprocessing <https://torchio.readthedocs.io/transforms/preprocessing.html#intensity>`_
            and
            `augmentation <https://torchio.readthedocs.io/transforms/augmentation.html#intensity>`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 3D
            :py:class:`torch.Tensor` with dimensions :math:`(D, H, W)`,
            where :math:`D, H, W` are the spatial dimensions. A channels
            dimension is prepended internally, so the stored tensor has
            shape :math:`(1, D, H, W)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        **kwargs: Items that will be added to image dictionary within the
            subject sample. The keys ``DATA``, ``AFFINE``, ``TYPE``,
            ``PATH`` and ``STEM`` are reserved and cannot be used.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, if
            :attr:`path` is given together with :attr:`tensor` or
            :attr:`affine`, if the affine is not :math:`4 \times 4`, or if a
            reserved key is passed in ``kwargs``.
        TypeError: If :attr:`tensor` is not a :class:`torch.Tensor`, if
            :attr:`affine` is not a NumPy array, or if :attr:`path` cannot be
            converted to a :class:`~pathlib.Path`.
        RuntimeError: If :attr:`tensor` does not have exactly 3 dimensions.
        FileNotFoundError: If :attr:`path` is neither an existing file nor an
            existing directory.
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[torch.Tensor] = None,
            affine: Optional[np.ndarray] = None,
            **kwargs: Any,
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        if path is not None:
            if tensor is not None or affine is not None:
                message = 'If a path is given, tensor and affine must be None'
                raise ValueError(message)
        # In-memory data; returned by load() when no path is given
        self._tensor = self.parse_tensor(tensor)
        # parse_affine already falls back to the identity matrix for None
        self._affine = self.parse_affine(affine)
        for key in (DATA, AFFINE, TYPE, PATH, STEM):
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self.type = type
        self.is_sample = False  # set to True by ImagesDataset

    def __repr__(self):
        """Summarize shape, spacing, orientation and memory footprint.

        Requires the image data to be stored under ``DATA`` (and the affine
        under ``AFFINE``), since every property used here reads them.
        """
        properties = [
            f'shape: {self.shape}',
            f'spacing: {self.get_spacing_string()}',
            f'orientation: {"".join(self.orientation)}+',
            f'memory: {humanize.naturalsize(self.memory, binary=True)}',
        ]
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    @property
    def data(self):
        """Image tensor stored under the ``DATA`` key."""
        return self[DATA]

    @property
    def affine(self):
        """Affine matrix stored under the ``AFFINE`` key."""
        return self[AFFINE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor, :math:`(C, D, H, W)`."""
        return tuple(self[DATA].shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape without the leading channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine as computed by :func:`nibabel.aff2axcodes`."""
        return nib.aff2axcodes(self[AFFINE])

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine, as a tuple."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes occupied by the data tensor.

        The tensor's actual element size is used, so non-``float32`` data
        (e.g. ``int64`` label maps passed via ``tensor=``) is reported
        correctly instead of assuming 4 bytes per voxel.
        """
        tensor = self[DATA]
        return tensor.numel() * tensor.element_size()

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an expanded :class:`~pathlib.Path`.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a ``Path``.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    @staticmethod
    def parse_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Validate a 3D tensor and prepend a channels dimension.

        Args:
            tensor: ``None`` or a tensor with dimensions :math:`(D, H, W)`.

        Returns:
            ``None`` if ``tensor`` is ``None``, otherwise a view of shape
            :math:`(1, D, H, W)`.

        Raises:
            TypeError: If ``tensor`` is not a :class:`torch.Tensor`.
            RuntimeError: If ``tensor`` does not have exactly 3 dimensions.
        """
        if tensor is None:
            return None
        if not isinstance(tensor, torch.Tensor):
            message = f'Input tensor must be a PyTorch tensor, not {type(tensor)}'
            raise TypeError(message)
        num_dimensions = tensor.dim()
        if num_dimensions != 3:
            message = (
                'The input tensor must have 3 dimensions (D, H, W),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        tensor = tensor.unsqueeze(0)  # add channels dimension
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate an affine matrix, defaulting to the 4x4 identity.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self, check_nans: bool = True) -> Tuple[torch.Tensor, np.ndarray]:
        r"""Load the image from disk.

        The file is expected to be monomodal/grayscale and 2D or 3D.
        A channels dimension is added to the tensor.

        If the image was created from a tensor (no path), the stored
        in-memory tensor and affine are returned as-is and ``check_nans``
        is not applied.

        Args:
            check_nans: If ``True``, issues a warning if NaNs are found
                in the image

        Returns:
            Tuple containing a 4D data tensor of size
            :math:`(1, D_{in}, H_{in}, W_{in})`
            and a 2D 4x4 affine matrix
        """
        if self.path is None:
            return self._tensor, self._affine
        tensor, affine = read_image(self.path)
        # https://github.com/pytorch/pytorch/issues/9410#issuecomment-404968513
        tensor = tensor[(None,) * (3 - tensor.ndim)]  # force to be 3D
        # Remove next line and uncomment the two following ones once/if this issue
        # gets fixed:
        # https://github.com/pytorch/pytorch/issues/29010
        # See also https://discuss.pytorch.org/t/collating-named-tensors/78650/4
        tensor = tensor.unsqueeze(0)  # add channels dimension
        # name_dimensions(tensor, affine)
        # tensor = tensor.align_to('channels', ...)
        if check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        return tensor, affine

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1.

        :meth:`load` prepends singleton dimensions to 2D files, so a 2D
        image ends up with shape :math:`(C, 1, H, W)`.
        """
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Return the data tensor as a NumPy array via ``Tensor.numpy()``."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Convert the data and affine to a :class:`SimpleITK.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS (default) or LPS coordinates.

        Args:
            lps: If ``True``, return LPS coordinates instead of RAS.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        # Continuous (possibly fractional) index of the central voxel
        center_index = (size - 1) / 2
        # SimpleITK physical points are LPS; pass a plain list of floats
        left, posterior, superior = image.TransformContinuousIndexToPhysicalPoint(
            center_index.tolist())
        if lps:
            return (left, posterior, superior)
        # RAS is LPS with the first two axes negated
        return (-left, -posterior, superior)
214