Passed
Pull Request — master (#332)
by Fernando
01:14
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 8
dl 0
loc 42
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more or different data.

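There are several approaches to avoid long parameter lists, such as replacing a parameter with a method call, passing the whole object instead of several of its fields, or introducing a parameter object that groups related parameters.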

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys managed by Image itself; users may not pass them as extra
# items, and they are skipped when extra items are copied to a new image.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the 4D tensor read
            from :py:attr:`path` is moved to the front, i.e. it is interpreted
            as the channels dimension. This is not applied to :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters. Keys listed in ``PROTECTED_KEYS`` are
            reserved and cannot be used.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, or if
            a reserved key is passed in ``**kwargs``.
        FileNotFoundError: If a given path is neither a file nor a directory.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        # Must be set before _parse_tensor, which reads self.check_nans
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        if tensor is not None:
            # In-memory data: nothing will be read from disk
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Data and affine are loaded lazily on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        kwargs.update(self._get_extra_items())
        return self.__class__(**kwargs)

    def _get_extra_items(self) -> Dict[str, Any]:
        """Return the items whose keys are not in ``PROTECTED_KEYS``.

        Values are shared with this image (shallow), not copied.
        """
        return {
            key: value
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        }

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.tensor`."""
        return self[DATA]

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @property
    def type(self) -> str:
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise a :py:class:`RuntimeError` if the image is not 2D."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), so that it can be used
            on tensors with or without a leading channels dimension.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -3 + index
            return index

    @staticmethod
    def flip_axis(axis: str) -> str:
        """Return the orientation code opposite to ``axis``.

        For example, ``'R'`` becomes ``'L'`` and ``'S'`` becomes ``'I'``.

        Raises:
            ValueError: If ``axis`` is not one of ``'LRPAIS'``.
        """
        opposites = {
            'R': 'L',
            'L': 'R',
            'A': 'P',
            'P': 'A',
            'I': 'S',
            'S': 'I',
        }
        if axis in opposites:
            return opposites[axis]
        values = ', '.join('LRPAISTB')
        message = f'Axis not understood. Please use one of: {values}'
        raise ValueError(message)

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        # Voxel edges are half a voxel away from the voxel centers
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an expanded :py:class:`~pathlib.Path`.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            RuntimeError: If the conversion fails for another reason.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Union[Path, List[Path], None]:
        """Parse a single path or a sequence of paths; ``None`` is passed through."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D float32 tensor; ``None`` is passed through.

        Raises:
            ValueError: If the tensor is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` as a 4D tensor using ``ensure_4d``."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning an identity matrix if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` is not :math:`4 \\times 4`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and store it in the image dictionary.

        The tensor and affine matrix are stored under the ``DATA`` and
        ``AFFINE`` keys, so they become available as :py:attr:`data` and
        :py:attr:`affine`. This method does nothing if the image has already
        been loaded.

        If :py:attr:`path` is a sequence of paths, the tensors are concatenated
        along the channel dimension and the affine matrix of the first file is
        used. A :py:class:`RuntimeWarning` is issued if another file has a
        different affine matrix.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            # Only the spatial dimensions must match; channels are concatenated
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape} '
                    f'and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        """Read one file, make its tensor 4D and optionally check for NaNs.

        If :py:attr:`channels_last` is ``True``, the last dimension is moved
        to the front.

        Returns:
            Tuple ``(tensor, affine)`` as returned by ``read_image``, with the
            tensor reshaped as described above.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the size of the last dimension is 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        """Set whether data read from now on is checked for NaNs."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot central slices of the image."""
        from ..visualization import plot_volume  # avoid circular import
        plot_volume(self, **kwargs)

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels in ``[index_ini, index_fin)``.

        The affine origin is moved to the world position of ``index_ini``;
        extra (non-protected) items are shared with the new image.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
        )
        kwargs.update(self._get_extra_items())
        return self.__class__(**kwargs)
483
484
485
class ScalarImage(Image):
    """Image whose type is always :py:attr:`torchio.INTENSITY`.

    This is a convenience subclass of :py:class:`~torchio.Image` that fills in
    the ``type`` argument for you.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
515
516
517
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Multiple files must be given as a single sequence (the ``path``
    argument). Their data are concatenated along the channel dimension.
    Tensors must be 4D, with dimensions :math:`(C, W, H, D)`.

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        # Reject conflicting types, then always force LABEL
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
541