Passed
Pull Request — master (#204)
by Fernando
59s
created

torchio.data.image.Image.load()   A

Complexity

Conditions 4

Size

Total Lines 28
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 11
nop 1
dl 0
loc 28
rs 9.85
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
12
from ..torchio import (
13
    TypeData,
14
    TypePath,
15
    TypeTripletInt,
16
    TypeTripletFloat,
17
    DATA,
18
    TYPE,
19
    AFFINE,
20
    PATH,
21
    STEM,
22
    INTENSITY,
23
)
24
from .io import read_image, write_image
25
26
27
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: t1.nii.gz; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or the `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 3D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(D, H, W)`. A channels dimension is prepended and the
            values are cast to float, so the stored data has shape
            :math:`(1, D, H, W)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            when the image is loaded from :attr:`path`.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        if path is not None:
            if tensor is not None or affine is not None:
                message = 'If a path is given, tensor and affine must be None'
                raise ValueError(message)
        self._loaded = False
        tensor = self.parse_tensor(tensor)
        # parse_affine already substitutes the identity matrix for None
        affine = self.parse_affine(affine)
        if tensor is not None:
            # In-memory image: data is available immediately, nothing to load
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in (DATA, AFFINE, TYPE, PATH, STEM):
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.is_sample = False  # set to True by ImagesDataset
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a (possibly expensive) load just to print
            properties.append(f'path: {self.path}')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        """Return ``self[item]``, loading from disk first for data/affine."""
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image data tensor (loaded lazily on first access)."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias for :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """4x4 affine matrix (loaded lazily on first access)."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type passed at construction, e.g. ``INTENSITY``."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor, including the channels dimension."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data tensor without the channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, as returned by ``nib.aff2axcodes``."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        """Estimated size of the data in bytes."""
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def get_spacing_string(self):
        """Return the spacing formatted as ``(x.xx, y.yy, z.zz)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`Path`, or return ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as err:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from err
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    @staticmethod
    def parse_tensor(tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Validate a 3D tensor/array and return it as a 4D float tensor.

        Returns ``None`` if ``tensor`` is ``None``.

        Raises:
            RuntimeError: If the input does not have exactly 3 dimensions.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        num_dimensions = tensor.dim()
        if num_dimensions != 3:
            message = (
                'The input tensor must have 3 dimensions (D, H, W),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        tensor = tensor.unsqueeze(0)  # add channels dimension
        tensor = tensor.float()
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate a 4x4 NumPy affine, defaulting to identity for ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> Tuple[torch.Tensor, np.ndarray]:
        r"""Load the image from disk and cache it in the image dictionary.

        The file is expected to be monomodal/grayscale and 2D or 3D.
        A channels dimension is added to the tensor.

        Returns:
            Tuple containing a 4D data tensor of size
            :math:`(1, D_{in}, H_{in}, W_{in})`
            and a 2D 4x4 affine matrix.

        Raises:
            RuntimeError: If the image was created without a path.
        """
        if self.path is None:
            raise RuntimeError('No path provided for instance of Image')
        tensor, affine = read_image(self.path)
        # Prepend singleton axes so lower-dimensional images become 3D,
        # e.g. (H, W) -> (1, H, W). With 3 or more dimensions the index tuple
        # is empty and the tensor is left unchanged.
        # https://github.com/pytorch/pytorch/issues/9410#issuecomment-404968513
        tensor = tensor[(None,) * (3 - tensor.ndim)]  # force to be 3D
        # Remove next line and uncomment the two following ones once/if this issue
        # gets fixed:
        # https://github.com/pytorch/pytorch/issues/29010
        # See also https://discuss.pytorch.org/t/collating-named-tensors/78650/4
        tensor = tensor.unsqueeze(0)  # add channels dimension
        # name_dimensions(tensor, affine)
        # tensor = tensor.align_to('channels', ...)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True
        return tensor, affine

    def save(self, path):
        """Save image to disk.

        Only the channels dimension is removed before writing; the first
        spatial dimension is also removed when the image is 2D, i.e. has
        shape :math:`(1, 1, H, W)`. Other singleton spatial dimensions are
        kept so that 3D volumes are not accidentally written as 2D.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        tensor = self[DATA].squeeze(0)  # remove channels dimension
        if self.is_2d():
            tensor = tensor.squeeze(0)  # (1, H, W) -> (H, W)
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Return the data tensor as a NumPy array."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Return the first channel as a SimpleITK image."""
        return nib_to_sitk(self[DATA][0], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS (default) or LPS coordinates."""
        image = self.as_sitk()
        size = np.array(image.GetSize())
        # Pass plain Python floats to the SimpleITK wrapper
        center_index = ((size - 1) / 2).tolist()
        # SimpleITK physical points are in LPS
        l, p, s = image.TransformContinuousIndexToPhysicalPoint(center_index)
        if lps:
            return (l, p, s)
        else:
            return (-l, -p, s)

    def set_check_nans(self, check_nans):
        """Enable or disable the NaN warning issued by :meth:`load`."""
        self.check_nans = check_nans
300