Passed
Push — master ( 85bce8...47d3da )
by Fernando
01:13
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 8
dl 0
loc 42
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

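There are several approaches to avoid long parameter lists, such as grouping related parameters into a single parameter object, or passing a whole object instead of several values extracted from it.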

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Keys managed by Image itself. Image.__init__ rejects them when passed as
# extra keyword arguments, and __copy__/crop skip them when forwarding the
# user-supplied items to the new image.
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`. It is converted to ``float32``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the last dimension of the data read from
            :py:attr:`path` is moved to the first position, so that it is used
            as the channels dimension. This is not applied to :attr:`tensor`.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither :py:attr:`path` nor :attr:`tensor` is given,
            if :attr:`tensor` is not 4D, or if a key in :attr:`kwargs` is
            reserved.
        TypeError: If :attr:`tensor` is neither a tensor nor a NumPy array,
            or if :attr:`affine` is not a NumPy array.
        FileNotFoundError: If a path does not point to an existing file or
            directory.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        # These attributes must exist before parse_tensor() is called below,
        # because it reads self.check_nans.
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            # In-memory data: the image is considered loaded, and the file at
            # ``path`` (if any) will never be read.
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        # dict.__init__ adds the extra items without clearing DATA/AFFINE.
        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        # Only report data-dependent properties if that does not trigger a
        # load from disk.
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access.
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        kwargs.update(self._get_extra_items())  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    def _get_extra_items(self) -> Dict[str, Any]:
        """Return the user-supplied items, i.e. those not in PROTECTED_KEYS."""
        return {
            key: value
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        }

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    def check_is_2d(self):
        """Raise a :class:`RuntimeError` if the last spatial size is not 1."""
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case versions
            and first letters are also valid, as only the first letter will be
            used.

        Returns:
            A negative index (``-3``, ``-2`` or ``-1``), so that it can be used
            on tensors with or without the channels dimension.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -3 + index
            return index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis):
        """Return the code of the opposite direction, e.g. ``'L'`` for ``'R'``."""
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        # Voxel indices refer to voxel centers, so the outer edges are half a
        # voxel beyond the first and last indices.
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        """Convert ``path`` to an expanded :class:`Path` that must exist."""
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message) from error
        except RuntimeError as error:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message) from error

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        """Parse one path, a sequence of paths, or ``None``."""
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert the input data to a 4D ``float32`` tensor.

        Args:
            tensor: PyTorch tensor, NumPy array or ``None``.

        Returns:
            ``None`` if ``tensor`` is ``None``, otherwise a ``float32`` tensor.

        Raises:
            TypeError: If ``tensor`` is neither a tensor nor a NumPy array.
            ValueError: If ``tensor`` is not 4D.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        else:
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' but type "{type(tensor)}" was found'
            )
            raise TypeError(message)
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate the affine, returning the identity if ``None``."""
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The data read from each path is concatenated along the first (channels)
        dimension and stored under ``DATA``. The affine of the first path is
        stored under ``AFFINE``. If the image has already been loaded, this
        method does nothing.

        Raises:
            RuntimeError: If the spatial shapes of the files do not match.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            # Differing affines are suspicious but not fatal.
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            # Differing spatial shapes make concatenation impossible.
            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'Files shape do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        """Read one file, make it 4D and optionally move channels first."""
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels in ``[index_ini, index_fin)``.

        The affine of the new image is translated so that its origin is the
        world position of ``index_ini``. Extra (non-reserved) items are
        forwarded to the new image.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
        )
        kwargs.update(self._get_extra_items())  # should I copy? deepcopy?
        return self.__class__(**kwargs)
460
461
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # An explicit type is accepted only if it agrees with the forced one.
        if kwargs.get('type', INTENSITY) != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
491
492
493
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given.
    """
    def __init__(self, *args, **kwargs):
        # An explicit type is accepted only if it agrees with the forced one.
        if kwargs.get('type', LABEL) != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
517