Passed
Pull Request — master (#248)
by Fernando
01:11
created

torchio.data.image.Image.__getitem__()   A

Complexity

Conditions 3

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 5
nop 2
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys managed by Image itself; user-supplied kwargs may not use them.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
            the dimensions meanings. If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimensions,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs. The setting is propagated to copies and crops.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither ``path`` nor ``tensor`` is given, or if a key in
            ``kwargs`` is one of the reserved keys.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        # Validate user keys before doing any (possibly expensive) tensor work
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        self._loaded = False

        # Channels are typically stored in the last dimension on disk, but a
        # tensor passed directly is expected to have channels in the first one
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True

        # dict.__init__ on an existing dict only adds the given items
        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        """Return a dictionary item, loading data/affine from disk on demand."""
        if item in (DATA, AFFINE) and item not in self:
            self._load()
        return super().__getitem__(item)

    def __array__(self, dtype=None, copy=None):
        """Return the image data as a NumPy array.

        Accepts the ``dtype`` and ``copy`` arguments that NumPy passes to the
        ``__array__`` protocol (``copy`` is passed by NumPy >= 2).
        """
        array = self[DATA].numpy()  # shares memory with the tensor
        if dtype is not None:
            array = array.astype(dtype, copy=bool(copy))
        elif copy:
            array = array.copy()
        return array

    def _get_new_instance_kwargs(self, tensor, affine) -> Dict[str, Any]:
        """Build constructor kwargs for a new image sharing this one's metadata."""
        kwargs = dict(
            tensor=tensor,
            affine=affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            channels_last=False,  # stored tensors are already channels-first
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are shared, not copied
        return kwargs

    def __copy__(self):
        kwargs = self._get_new_instance_kwargs(self.data, self.affine)
        return self.__class__(**kwargs)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, computed by :func:`nibabel.aff2axcodes`."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        """Number of bytes used by the image data, based on its actual dtype."""
        data = self.data
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`Path`, or return ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D torch tensor, warning about NaNs."""
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``; ``None`` becomes a 4x4 identity matrix.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk and cache it in the dictionary.

        Stores a 4D tensor of size :math:`(C, D, H, W)` under ``DATA`` and the
        :math:`4 \times 4` affine matrix under ``AFFINE``. Does nothing if the
        image has already been loaded.
        """
        if self._loaded:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels in ``[index_ini, index_fin)``.

        The affine origin is moved to the world position of ``index_ini`` so
        the cropped voxels keep their world coordinates.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = self._get_new_instance_kwargs(patch, new_affine)
        return self.__class__(**kwargs)
367
368
369
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
397
398
399
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
418