Passed
Push — master ( 85bce8...47d3da )
by Fernando
01:13
created

torchio.data.image.LabelMap.__init__()   A

Complexity

Conditions 3

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 5
nop 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Keys managed by Image itself: they may not be passed as extra keyword items
# to the constructor and are skipped when an image is re-instantiated.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a warning is issued and :attr:`torchio.INTENSITY` is
            used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`. It is converted to ``float32``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the data read from
            :attr:`path` is moved to the front, so that the channels end up
            in the first dimension. It is not applied to :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        # An in-memory tensor means the image is already "loaded"; otherwise
        # data and affine are read lazily from path on first access
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        # Only describe the data if it is in memory, to avoid triggering a load
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        """Return a new image of the same class built from this image.

        The data tensor, affine and extra items are passed by reference to
        the new instance; they are not deep-copied.
        """
        return self._new_image_like(self.data, self.affine)

    def _new_image_like(self, tensor, affine):
        """Instantiate an image of the same class with new data and affine.

        The type, path, ``check_nans`` and ``channels_last`` settings and
        every non-protected item of this image are carried over.
        """
        kwargs = dict(
            tensor=tensor,
            affine=affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            channels_last=self.channels_last,
        )
        for key, value in self.items():
            if key not in PROTECTED_KEYS:
                kwargs[key] = value  # shallow: values are shared, not copied
        return self.__class__(**kwargs)

    @property
    def data(self):
        """Image data; loaded from disk on first access if needed."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias for :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix; loaded from disk on first access if needed."""
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape without the first (channels) dimension."""
        return self.shape[1:]

    def check_is_2d(self):
        """Raise :class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Size of the second spatial dimension of a 2D image."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Size of the first spatial dimension of a 2D image."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self):
        """Axis codes of the affine, as returned by ``nib.aff2axcodes``."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        """Estimated data size in bytes, assuming 4 bytes per voxel."""
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case versions
            and first letters are also valid, as only the first letter will be
            used.

        Returns:
            A negative index, so that it is valid whether or not the channels
            dimension is included.

        Raises:
            ValueError: If :attr:`axis` is not a non-empty string or cannot be
                matched to the image orientation.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis):
        """Return the opposite anatomical axis code (e.g. ``'R'`` -> ``'L'``)."""
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)

    def get_spacing_string(self):
        """Format the spacing with two decimals, e.g. ``'(1.00, 1.00, 1.00)'``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        # Voxel centers are at integer indices, so the outer edges of the
        # first and last voxels are half a voxel away
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Expand ``~`` and check that the path is an existing file or dir."""
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse a single path or a sequence of paths (``None`` passes through)."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def parse_tensor(self, tensor: TypeData) -> Optional[torch.Tensor]:
        """Convert the input to a 4D ``float32`` tensor.

        Args:
            tensor: A :class:`torch.Tensor`, a NumPy array or ``None``.

        Returns:
            ``None`` if :attr:`tensor` is ``None``, otherwise a float tensor.

        Raises:
            TypeError: If :attr:`tensor` is not a tensor or NumPy array.
            ValueError: If :attr:`tensor` is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        else:
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate the affine; ``None`` becomes a 4x4 identity matrix.

        Raises:
            TypeError: If :attr:`affine` is not a NumPy array.
            ValueError: If :attr:`affine` is not :math:`4 \\times 4`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        Every file in :attr:`path` is read and the tensors are concatenated
        along the channel dimension. The result is stored under
        :attr:`torchio.DATA`, and the affine of the first file under
        :attr:`torchio.AFFINE`. Calling this method on an image that is
        already loaded does nothing.

        A :class:`RuntimeWarning` is issued if the files have different
        affine matrices.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            # Concatenation on the channel dimension requires equal spatial
            # shapes; previously this error was built but never raised
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        """Read one file, make it 4D and optionally check it for NaNs."""
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            # Move the last dimension to the front (channels first)
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels from ``index_ini`` (inclusive)
        to ``index_fin`` (exclusive).

        The affine is translated so that the new first voxel keeps its world
        position. The cropped data is cloned; other items are carried over.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        return self._new_image_like(patch, new_affine)
460
461
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        # A missing type counts as INTENSITY; anything else is rejected
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
491
492
493
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Tensors must be 4D, with the channels in the first dimension. Multiple
    files must be passed as a single sequence, since the second positional
    argument of :py:class:`~torchio.data.image.Image` is :py:attr:`type`.

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        # Force the label type so transforms use nearest-neighbor resampling
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
517