Passed
Push — master ( 0d2a88...eb3c35 )
by Fernando
01:36
created

torchio.data.image.Image.data()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys that Image manages itself; Image.__init__ rejects them when
# they are passed as extra keyword arguments, and Image.crop skips them when
# copying the remaining items to the new image.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
            the dimensions meanings. If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimensions,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array or :py:class:`torch.Tensor` (a
            tensor is converted to a NumPy array). If ``None``,
            :attr:`affine` is an identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: t1.nii.gz; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Number of channels are typically stored in the last dimensions in disk
        # But if a tensor is given, the channels should be in the first dim
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        # dict.__init__ only adds the extra items; DATA/AFFINE set above stay
        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a disk read just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: the data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    @property
    def data(self):
        """Image data tensor (loaded from disk on first access)."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias for :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix (loaded from disk on first access)."""
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        # Everything after the channels dimension
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes used by the data tensor.

        Computed from the tensor's actual dtype instead of assuming 4 bytes
        per voxel, so e.g. boolean label maps or float64 data are reported
        correctly.
        """
        data = self.data
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`~pathlib.Path`, or ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D :class:`torch.Tensor` (or ``None``)."""
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        """Validate ``affine`` and return it as a 4 x 4 NumPy array.

        ``None`` becomes the identity; a :class:`torch.Tensor` is converted
        to NumPy.

        Raises:
            TypeError: If ``affine`` is not a NumPy array or tensor.
            ValueError: If ``affine`` is not 4 x 4.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            affine = affine.detach().cpu().numpy()
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk, if not loaded yet.

        Stores a 4D tensor of size :math:`(C, D, H, W)` under ``DATA`` and
        the :math:`4 \times 4` affine matrix returned by ``read_image`` under
        ``AFFINE``. Nothing is returned.
        """
        if self._loaded:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        # NOTE(review): checks the first spatial dimension; confirm this
        # matches where ensure_4d places the singleton axis of 2D inputs
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels in ``[index_ini, index_fin)``.

        All channels are kept. The affine origin is moved to the world
        position of ``index_ini``. Non-protected dictionary items are copied
        (shallowly) to the new image.

        Args:
            index_ini: Inclusive start voxel index along each spatial axis.
            index_fin: Exclusive end voxel index along each spatial axis.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Keep every channel and stay 4D, so the constructor does not have to
        # guess the meaning of the dimensions (a 3D patch with one side of
        # length 3 would otherwise be taken for a multichannel 2D image)
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            # The patch is already channels-first; without this, passing
            # `path` would make the constructor default to channels_last=True
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow copy: values are shared with self
        return self.__class__(**kwargs)
349
350
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given, either as a keyword or
            positionally.
    """
    def __init__(self, *args, **kwargs):
        # ``type`` is the second positional parameter of Image.__init__;
        # injecting it as a keyword as well would raise "multiple values"
        if len(args) > 1:
            if args[1] != INTENSITY:
                raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        else:
            if 'type' in kwargs and kwargs['type'] != INTENSITY:
                raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
            kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
378
379
380
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given, either as a keyword or positionally.
    """
    def __init__(self, *args, **kwargs):
        # ``type`` is the second positional parameter of Image.__init__;
        # injecting it as a keyword as well would raise "multiple values"
        if len(args) > 1:
            if args[1] != LABEL:
                raise ValueError('Type of LabelMap is always torchio.LABEL')
        else:
            if 'type' in kwargs and kwargs['type'] != LABEL:
                raise ValueError('Type of LabelMap is always torchio.LABEL')
            kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
399