Passed
Push — master ( ddc71b...32d696 )
by Fernando
01:11
created

torchio.data.image.Image.get_center()   A

Complexity

Conditions 2

Size

Total Lines 15
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 2
dl 0
loc 15
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Keys that Image manages itself; extra keyword arguments may not use them.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
            the dimensions meanings. If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimensions,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
            A given tensor is converted to ``float32``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix. It is only used when :attr:`tensor` is given;
            otherwise the affine is read from :attr:`path`.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, or if
            a key in :attr:`kwargs` is reserved.
        FileNotFoundError: If :attr:`path` does not point to a file or
            directory.
        TypeError: If :attr:`path` cannot be converted to a path, or if
            :attr:`affine` is not a NumPy array.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Number of channels are typically stored in the last dimensions in disk
        # But if a tensor is given, the channels should be in the first dim
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            # In-memory data: no lazy loading needed
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Data and affine are read from disk lazily on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self, dtype=None, copy=None):
        """Return the image data as a NumPy array (NumPy array protocol).

        Args:
            dtype: Optional NumPy dtype requested by the caller.
            copy: ``True`` to always return a copy, ``False`` to never copy
                (raise if a copy would be required), ``None`` to copy only
                when needed.
        """
        array = self[DATA].numpy()  # shares memory with the tensor
        if dtype is not None and array.dtype != np.dtype(dtype):
            if copy is False:
                raise ValueError(
                    'Unable to avoid a copy while converting the image dtype')
            return array.astype(dtype)
        if copy:
            return array.copy()
        return array

    def __copy__(self):
        # Shallow copy: the new image shares the same tensor and item values
        kwargs = self._get_kwargs_for_new_image(self.data, self.affine)
        return self.__class__(**kwargs)

    def _get_kwargs_for_new_image(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            ) -> Dict[str, Any]:
        """Build constructor kwargs for a new image that keeps this image's metadata."""
        kwargs = dict(
            tensor=tensor,
            affine=affine,
            type=self.type,
            path=self.path,
            channels_last=False,  # the tensor is already channels-first
            check_nans=self.check_nans,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # values are shared, not copied
        return kwargs

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm.

        Returns:
            Three lists (one per world axis), each containing the coordinate
            of the outer edge of the first voxel and of the last voxel.
        """
        # Voxel centers are at integer indices, so edges are half a voxel away
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk and cache it in the dictionary.

        After loading, ``self[DATA]`` holds a 4D tensor of size
        :math:`(C, D, H, W)` and ``self[AFFINE]`` a :math:`4 \times 4` affine
        matrix to convert voxel indices to world coordinates. Nothing is
        returned, and nothing is done if the image is already loaded.
        """
        if self._loaded:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.

        Returns:
            World coordinates of the center of the voxel grid.
        """
        size = np.array(self.spatial_shape)
        # Voxel centers sit at integer indices, so the grid center is at (n - 1) / 2
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            # RAS+ -> LPS+ flips the first two axes
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels in ``[index_ini, index_fin)``.

        The affine of the new image is moved so that its origin is the world
        position of ``index_ini`` in this image.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = self._get_kwargs_for_new_image(patch, new_affine)
        return self.__class__(**kwargs)
387
388
389
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a keyword argument :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given. Passing
            ``type=torchio.INTENSITY`` explicitly is allowed.
    """
    def __init__(self, *args, **kwargs):
        # An explicit type is tolerated only if it matches the forced one
        if kwargs.get('type', INTENSITY) != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
417
418
419
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a keyword argument :py:attr:`type` other than
            :py:attr:`torchio.LABEL` is given. Passing
            ``type=torchio.LABEL`` explicitly is allowed.
    """
    def __init__(self, *args, **kwargs):
        # An explicit type is tolerated only if it matches the forced one
        if kwargs.get('type', LABEL) != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
438