Passed
Pull Request — master (#223)
by Fernando
01:49
created

torchio.data.image.Image.memory()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
)
from .io import read_image, write_image


class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: t1.nii.gz; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or the `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 3D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(D, H, W)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, if
            :attr:`path` is given together with :attr:`tensor` or
            :attr:`affine`, or if a reserved key is used in :attr:`kwargs`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            **kwargs: Any,
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        if path is not None:
            if tensor is not None or affine is not None:
                message = 'If a path is given, tensor and affine must be None'
                raise ValueError(message)
        self._loaded = False
        tensor = self.parse_tensor(tensor)
        # parse_affine() already falls back to the identity matrix when the
        # affine is None, so no extra default is needed here
        affine = self.parse_affine(affine)
        if tensor is not None:
            # Data given in memory: there is nothing to load lazily
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in (DATA, AFFINE, TYPE, PATH, STEM):
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        """Summarize the image; data-dependent fields only once loaded."""
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the object
            properties.append(f'path: {self.path}')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        """Get an item, loading the image first if data or affine is missing."""
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image tensor stored under the ``DATA`` key (loaded on demand)."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias of :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix stored under the ``AFFINE`` key (loaded on demand)."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type stored under the ``TYPE`` key."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor as a tuple."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data tensor without the leading channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes computed from the affine with :func:`nib.aff2axcodes`."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Spacing extracted from the affine, as a tuple."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes taken by the image data.

        The size is computed from the actual element size of the tensor
        rather than assuming 4 bytes per voxel, as images loaded from disk
        are not cast to ``float32`` by :meth:`load`.
        """
        data = self.data
        return data.element_size() * data.numel()

    def get_spacing_string(self):
        """Format the spacing with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`pathlib.Path`.

        Args:
            path: Path to a file or directory, or ``None``.

        Returns:
            The path with the user directory expanded, or ``None`` if
            ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    @staticmethod
    def parse_tensor(tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert a 3D array or tensor into a 4D float tensor.

        Args:
            tensor: 3D :class:`torch.Tensor` or NumPy array, or ``None``.

        Returns:
            Float tensor with a channels dimension added at the front, or
            ``None`` if ``tensor`` is ``None``.

        Raises:
            TypeError: If ``tensor`` is neither a tensor nor a NumPy array.
            RuntimeError: If ``tensor`` does not have 3 dimensions.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        if not isinstance(tensor, torch.Tensor):
            # Fail with a clear message instead of an AttributeError below
            message = (
                'The input tensor must be a PyTorch tensor or a NumPy array,'
                f' not {type(tensor)}'
            )
            raise TypeError(message)
        num_dimensions = tensor.dim()
        if num_dimensions != 3:
            message = (
                'The input tensor must have 3 dimensions (D, H, W),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        tensor = tensor.unsqueeze(0)  # add channels dimension
        tensor = tensor.float()
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate the affine matrix.

        Args:
            affine: NumPy array of shape (4, 4), or ``None``.

        Returns:
            The input array, or a 4x4 identity matrix if ``affine`` is
            ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If the shape of ``affine`` is not (4, 4).
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The file is expected to be monomodal/grayscale and 2D or 3D.
        A channels dimension is added to the tensor.

        Nothing is done if the image is already loaded or has no path.
        Otherwise, a 4D data tensor of size
        :math:`(1, D_{in}, H_{in}, W_{in})` and the affine matrix returned by
        the reader are stored in the image under the ``DATA`` and ``AFFINE``
        keys. Nothing is returned.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        # https://github.com/pytorch/pytorch/issues/9410#issuecomment-404968513
        tensor = tensor[(None,) * (3 - tensor.ndim)]  # force to be 3D
        # Remove next line and uncomment the two following ones once/if this issue
        # gets fixed:
        # https://github.com/pytorch/pytorch/issues/29010
        # See also https://discuss.pytorch.org/t/collating-named-tensors/78650/4
        tensor = tensor.unsqueeze(0)  # add channels dimension
        # name_dimensions(tensor, affine)
        # tensor = tensor.align_to('channels', ...)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        tensor = self[DATA].squeeze()  # assume 2D if (1, 1, H, W)
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        """Return ``True`` if the third-to-last dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA][0], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        center_index = (size - 1) / 2
        point = image.TransformContinuousIndexToPhysicalPoint(center_index)
        left, posterior, superior = point
        if lps:
            return (left, posterior, superior)
        # Flip the first two axes to go from LPS+ to RAS+
        return (-left, -posterior, superior)

    def set_check_nans(self, check_nans):
        """Set whether a warning is issued if NaNs are found when loading."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Get a new image with the voxels between two corner indices.

        Args:
            index_ini: Three indices of the first voxel of the patch.
            index_fin: Three indices marking the (exclusive) end of the patch.

        Returns:
            New :class:`Image` with a copy of the cropped data, the same type
            and an affine whose translation is moved to ``index_ini``.
        """
        # TODO: add the rest of kwargs
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[0, i0:i1, j0:j1, k0:k1].clone()
        return Image(tensor=patch, affine=new_affine, type=self.type)
319