Passed
Push — master ( 53ab14...c2608f )
by Fernando
01:07
created

torchio.data.image.Image.height()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 1
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys managed by Image itself. They are rejected when passed as
# extra keyword arguments to the constructor and are skipped when the extra
# items of an image are forwarded to a new instance.
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`). If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, H, W, D)`. If it is not 4D, TorchIO will try to guess
            the dimensions meanings. If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimensions,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Number of channels are typically stored in the last dimensions in disk
        # But if a tensor is given, the channels should be in the first dim
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            # Data given in memory: nothing will ever be read from disk
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True

        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        """Summarize the image; data-dependent fields only if already loaded."""
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ]
        else:
            # Do not trigger a read from disk just to print the image
            properties = [f'path: "{self.path}"']
        properties.append(f'type: {self.type}')
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        """Get a dictionary item, loading the image lazily if needed."""
        # Data and affine are only available after reading the file(s)
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        """Create a new image sharing the tensor and affine of this one."""
        kwargs = self._get_copy_kwargs(self.data, self.affine)
        return self.__class__(**kwargs)

    def _get_copy_kwargs(self, tensor, affine) -> Dict[str, Any]:
        """Build constructor arguments for a new image derived from this one.

        Args:
            tensor: Data of the new image, with channels in the first dimension.
            affine: Affine matrix of the new image.

        Returns:
            Keyword arguments for the class constructor, including the
            non-protected items stored in this image.
        """
        kwargs = dict(
            tensor=tensor,
            affine=affine,
            type=self.type,
            path=self.path,
            channels_last=False,
            # Keep the user's choice so that derived images do not start
            # checking (and warning about) NaNs again
            check_nans=self.check_nans,
        )
        for key, value in self.items():
            if key not in PROTECTED_KEYS:
                kwargs[key] = value  # values are shared, not copied
        return kwargs

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    def check_is_2d(self):
        """Raise a :py:class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def width(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used. ``'Height'`` and ``'Width'`` are mapped
                to dimensions ``1`` and ``2``, respectively.

        Returns:
            Index of the tensor dimension. For anatomical names the index is
            negative (``-3``, ``-2`` or ``-1``).

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string or its first
                letter is not understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must not be an empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, H, W, D)
        if axis == 'H':
            return 1
        if axis == 'W':
            return 2

        try:
            index = self.orientation.index(axis)
        except ValueError:
            # The axis may be stored with the opposite direction
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not: there are three spatial
        # dimensions and they are always the last ones, so the first one
        # is -3 both in (C, H, W, D) and in (H, W, D)
        return -3 + index

    @staticmethod
    def flip_axis(axis):
        """Return the letter of the opposite anatomical direction.

        Raises:
            ValueError: If :attr:`axis` is not one of the letters in
                ``'LRAPIS'``.
        """
        opposites = {
            'R': 'L',
            'L': 'R',
            'A': 'P',
            'P': 'A',
            'I': 'S',
            'S': 'I',
        }
        try:
            return opposites[axis]
        except (KeyError, TypeError):  # unknown or unhashable value
            message = (
                f'Axis not understood. Please use a value in {tuple("LRAPIS")}'
            )
            raise ValueError(message) from None

    def get_spacing_string(self):
        """Format the spacing with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    def get_bounds(self):
        """Get image bounds in mm."""
        # Voxel indices refer to voxel centers, so the borders are half a
        # voxel away from the first and last indices
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert the input to an existing :py:class:`pathlib.Path`.

        Raises:
            TypeError: If the input cannot be interpreted as a path.
            RuntimeError: If the conversion to path fails.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Union[Path, List[Path], None]:
        """Parse a single path or a sequence of paths."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def parse_tensor(self, tensor: TypeData) -> Optional[torch.Tensor]:
        """Convert the input data to a 4D float tensor.

        Returns:
            ``None`` if the input is ``None``. Otherwise, the tensor returned
            by :py:meth:`parse_tensor_shape` after casting NumPy arrays and
            tensors to 32-bit floats.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: np.ndarray) -> np.ndarray:
        """Validate the affine matrix, defaulting to the identity.

        Raises:
            TypeError: If the affine is not a NumPy array.
            ValueError: If the shape of the affine is not :math:`(4, 4)`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _read_and_check(self, path: Path):
        """Read one file, make the data 4D and warn about NaNs if enabled.

        Returns:
            Tuple with the 4D tensor and the affine returned by the reader.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def load(self) -> None:
        """Load the image from disk.

        Nothing is done if the image is already loaded. Otherwise, every path
        is read, the tensors are concatenated along the first (channels)
        dimension and the result is stored in the image together with the
        affine of the first file.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return

        paths = self.path if isinstance(self.path, list) else [self.path]
        first_path = paths[0]
        tensor, affine = self._read_and_check(first_path)

        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self._read_and_check(path)

            if not np.array_equal(affine, new_affine):
                # Not fatal: the affine of the first file is the one kept
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {first_path}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)

            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)

            tensors.append(new_tensor)

        self[DATA] = torch.cat(tensors)
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        """Check whether the size of the last dimension is 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            # RAS+ to LPS+: flip the first two axes
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Extract a patch of the image as a new image.

        Args:
            index_ini: Indices of the first voxel of the patch.
            index_fin: Indices of the voxel after the last one in the patch,
                as in Python slicing.

        Returns:
            New instance of the same class containing a copy of the cropped
            data. The origin of its affine is moved to the position of
            :attr:`index_ini`.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = self._get_copy_kwargs(patch, new_affine)
        return self.__class__(**kwargs)
485
486
487
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` different from
            :py:attr:`torchio.INTENSITY` is passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine: it is always forced to INTENSITY below
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
517
518
519
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> paths = [                                   # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ]
        >>> tpm = torchio.LabelMap(paths)

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` different from
            :py:attr:`torchio.LABEL` is passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine: it is always forced to LABEL below
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
543