Passed
Pull Request — master (#246)
by Fernando
01:19
created

torchio.data.image.Image.spatial_shape()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys managed by Image itself; users may not pass them as kwargs
PROTECTED_KEYS = (
    DATA,
    AFFINE,
    TYPE,
    PATH,
    STEM,
)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
            the meaning of the dimensions. If 2D, the shape will be interpreted
            as :math:`(H, W)`. If 3D, the number of spatial dimensions should
            be specified with :attr:`num_spatial_dims`. If
            :attr:`num_spatial_dims` is not given and the shape is 3 along the
            first or last dimension, it will be interpreted as a multichannel
            2D image. Otherwise, it will be interpreted as a 3D image with a
            single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found when the
            image is loaded from :attr:`path`. If ``False``, images will not be
            checked for the presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, or if
            a reserved key is passed in ``kwargs``.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False
        self.num_spatial_dims = num_spatial_dims

        # Channels are typically stored in the last dimension on disk, but a
        # tensor given directly is expected to have them in the first one
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        # parse_affine never returns None: a missing affine becomes identity
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        # Only report data-derived properties once loaded, so that printing
        # an image never triggers a read from disk
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: accessing the data or affine reads the file on demand
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    @property
    def data(self):
        """Image data tensor (loaded from disk on first access)."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias of :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix (loaded from disk on first access)."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor as a tuple."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data without the first (channels) dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, as computed by ``nib.aff2axcodes``."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        """Estimated data size in bytes.

        Assumes 4 bytes per element (float32), regardless of the actual
        dtype of the tensor.
        """
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def get_spacing_string(self):
        """Format the spacing as ``'(x.xx, y.yy, z.zz)'``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an expanded :class:`Path` that must exist.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a :class:`Path`.
            FileNotFoundError: If it is neither an existing file nor a
                directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message)
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert a NumPy array or tensor to a 4D tensor; ``None`` passes through."""
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape ``tensor`` to 4D using this image's channel settings."""
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning an identity matrix if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If its shape is not ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk and cache it in the dictionary.

        Does nothing if the image is already loaded or has no path. Otherwise
        stores a 4D tensor of size :math:`(C, D, H, W)` in ``self[DATA]`` and
        the :math:`4 \times 4` affine matrix that converts voxel indices to
        world coordinates in ``self[AFFINE]``. Warns if NaNs are found and
        :attr:`check_nans` is ``True``.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the third-to-last data dimension has size 1."""
        # NOTE(review): relies on how ensure_4d pads 2D inputs to 4D —
        # confirm which axis is the singleton one
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans):
        """Set whether NaNs should be checked for when loading from disk."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image containing the voxels inside a bounding box.

        Args:
            index_ini: Spatial voxel indices ``(i0, j0, k0)`` of the first
                corner (inclusive).
            index_fin: Spatial voxel indices ``(i1, j1, k1)`` of the opposite
                corner (exclusive, as in Python slicing).

        All channels are kept. The affine origin is moved to ``index_ini`` so
        the patch keeps its position in world coordinates. Non-reserved
        dictionary items are passed on to the new image by reference.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Keep every channel: indexing channel 0 only dropped the rest and
        # handed a 3D tensor to the constructor, which then had to guess
        # its layout
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            # The in-memory data is already channels-first (C, D, H, W).
            # Without this, the constructor defaults to channels_last=True
            # whenever a path is given and would reinterpret the patch.
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # NOTE(review): shared reference; copy?
        return self.__class__(**kwargs)
349
350
351
class ScalarImage(Image):
    """Image whose type is always :py:attr:`torchio.INTENSITY`.

    Convenience subclass of :py:class:`~torchio.Image`; all arguments have
    the same meaning as there, except that :py:attr:`type` is fixed.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # An absent 'type' is treated as if INTENSITY had been requested
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
379
380
381
class LabelMap(Image):
    """Image whose type is always :py:attr:`torchio.LABEL`.

    Convenience subclass of :py:class:`~torchio.Image`; all arguments have
    the same meaning as there, except that :py:attr:`type` is fixed.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        # An absent 'type' is treated as if LABEL had been requested
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
400