Passed
Pull Request — master (#334)
by Fernando
01:13
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 8
dl 0
loc 42
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more or different data.

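There are several approaches to avoid long parameter lists, for example: introducing a parameter object that groups related values, passing a whole object instead of several of its attributes, or computing a value inside the method instead of passing it in.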

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
from PIL import Image as ImagePIL
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from ..utils import (
13
    nib_to_sitk,
14
    get_rotation_and_spacing_from_affine,
15
    get_stem,
16
    ensure_4d,
17
)
18
from ..torchio import (
19
    TypeData,
20
    TypePath,
21
    TypeTripletInt,
22
    TypeTripletFloat,
23
    DATA,
24
    TYPE,
25
    AFFINE,
26
    PATH,
27
    STEM,
28
    INTENSITY,
29
    LABEL,
30
)
31
from .io import read_image, write_image
32
33
34
# Keys managed by Image itself; users may not pass them as extra **kwargs.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
35
36
37
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`. It is converted to ``float32``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the data read from
            :attr:`path` is moved to the first (channels) position. It is not
            applied to data given through :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, if
            :attr:`tensor` is not 4D, if :attr:`affine` is not
            :math:`4 \times 4`, or if a reserved key is passed in
            :attr:`kwargs`.
        TypeError: If :attr:`affine` is not a NumPy array.
        FileNotFoundError: If a given path is neither a file nor a directory.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: t1.nii.gz; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Any,
            ):
        # These must be set before _parse_tensor, which reads check_nans
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self._parse_tensor(tensor)
        affine = self._parse_affine(affine)
        # Data given in memory is stored right away; otherwise it is read
        # lazily from disk the first time DATA or AFFINE is accessed
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to build the representation
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        """Return a new image of the same class sharing data and extra items.

        The tensor, affine, type, path and the ``check_nans`` and
        ``channels_last`` settings are forwarded to the constructor. Extra
        (non-protected) items are copied by reference, not duplicated.
        """
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            # Keep the original configuration in the copy
            check_nans=self.check_nans,
            channels_last=self.channels_last,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are shared with the original
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.tensor`."""
        return self[DATA]

    @property
    def tensor(self) -> torch.Tensor:
        """Tensor data. Same as :py:class:`Image.data`."""
        return self.data

    @property
    def affine(self) -> np.ndarray:
        """Affine matrix to transform voxel indices into world coordinates."""
        return self[AFFINE]

    @property
    def type(self) -> str:
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Tensor shape as :math:`(C, W, H, D)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Tensor spatial shape as :math:`(W, H, D)`."""
        return self.shape[1:]

    def check_is_2d(self) -> None:
        """Raise :class:`RuntimeError` if the last spatial dimension is not 1."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        """Image height, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        """Image width, if 2D."""
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        """Orientation codes."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        """Voxel spacing in mm."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        """Number of Bytes that the tensor takes in the RAM."""
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    @property
    def bounds(self) -> np.ndarray:
        """Position of centers of voxels in smallest and largest coordinates."""
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str) -> int:
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``) into the tensor
            dimensions, so it is valid with or without the channels dimension.

        Raises:
            ValueError: If ``axis`` is not a non-empty string or is not
                understood.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        if not axis:
            raise ValueError('Axis must be a non-empty string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Return negative indices so that it does not matter whether we
        # refer to spatial dimensions or not
        return -3 + index

    @staticmethod
    def flip_axis(axis):
        """Return the orientation code opposite to ``axis`` (e.g. R -> L).

        Raises:
            ValueError: If ``axis`` is not one of ``L, R, P, A, I, S``.
        """
        opposites = {
            'R': 'L',
            'L': 'R',
            'A': 'P',
            'P': 'A',
            'I': 'S',
            'S': 'I',
        }
        try:
            return opposites[axis]
        except (KeyError, TypeError):  # TypeError: unhashable input
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, ...)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm.

        Unlike :attr:`bounds`, these are the outer edges of the corner voxels
        (half a voxel beyond their centers), one ``[min, max]``-style pair of
        coordinates per axis.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an expanded :class:`Path` that must exist.

        Raises:
            TypeError: If ``path`` cannot be converted to a :class:`Path`.
            FileNotFoundError: If it is neither a file nor a directory.
        """
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None],
            ) -> Union[Path, List[Path], None]:
        """Return ``None``, a single :class:`Path`, or a list of them."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        return [self._parse_single_path(p) for p in path]

    def _parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D ``float32`` PyTorch tensor.

        Raises:
            ValueError: If the tensor is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Delegate to ``ensure_4d`` to give data read from disk a 4D shape."""
        return ensure_4d(tensor)

    @staticmethod
    def _parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``; ``None`` becomes a 4x4 identity matrix.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If its shape is not ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        Every file in :attr:`path` is read and the resulting tensors are
        concatenated along the channel dimension. The concatenated tensor and
        the affine of the first file are stored in the image. If the image is
        already loaded, nothing is done.

        A :class:`RuntimeWarning` is issued if the files have different affine
        matrices; the affine of the first file is kept.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        self[DATA] = torch.cat(tensors)
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        """Read one file and return ``(tensor, affine)``.

        The tensor is made 4D, its last dimension is moved first if
        ``channels_last`` is set, and NaNs are reported if ``check_nans`` is
        set.
        """
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the last spatial dimension has size 1."""
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def as_pil(self):
        """Get the image as an instance of :py:class:`PIL.Image`.

        The file at :attr:`path` is opened with PIL, so this requires the
        image to have been created from a single file path.
        """
        self.check_is_2d()
        return ImagePIL.open(self.path)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        """Enable or disable the NaN check for subsequent reads."""
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        """Plot image."""
        if self.is_2d():
            self.as_pil().show()
        else:
            from ..visualization import plot_volume  # avoid circular import
            plot_volume(self, **kwargs)
474
475
476
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        # Passing type=INTENSITY explicitly is tolerated; anything else is not
        if kwargs.get('type', INTENSITY) != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
506
507
508
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    Tensors must be 4D, with the channels first. Multiple files are given as
    a single sequence of paths; their data is concatenated on the channel
    dimension.

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        # Passing type=LABEL explicitly is tolerated; anything else is not
        if kwargs.get('type', LABEL) != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
532