Passed
Pull Request — master (#332)
by Fernando
Duration 03:27 (queued 02:14)

torchio.data.image.Image.__init__() (grade C)

Complexity

Conditions 9

Size

Total Lines 42
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      9
eloc    33
nop     8
dl      0
loc     42
rs      6.6666
c       0
b       0
f       0

How to fix

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as callers need more, or different, data.

There are several approaches to avoid long parameter lists:

- Introduce Parameter Object: group parameters that naturally belong together into a single object, as sketched below.
- Preserve Whole Object: pass an existing object instead of several of its attributes.
- Replace Parameter with Method Call: let the method compute a value itself instead of receiving it.
- Extract smaller methods so that each one needs only part of the data.
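For instance, the flagged `__init__()` takes eight parameters (nop 8 in the metrics above). The following is a minimal sketch of the first approach applied to a signature like this one; the `LoadingOptions` dataclass and the refactored signature are hypothetical and not part of the torchio API:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class LoadingOptions:
    """Hypothetical parameter object grouping the flags that control
    how image data is read and checked."""
    check_nans: bool = False
    channels_last: bool = False


class Image(dict):
    def __init__(
            self,
            path: Optional[str] = None,
            type: Optional[str] = None,
            tensor: Any = None,
            affine: Any = None,
            options: Optional[LoadingOptions] = None,  # replaces two flags
            **kwargs: Any,
            ):
        super().__init__(**kwargs)
        options = options if options is not None else LoadingOptions()
        self.check_nans = options.check_nans
        self.channels_last = options.channels_last


# Related options now travel together, and new loading flags can be
# added to LoadingOptions without widening the signature again.
image = Image(path='t1.nii.gz', options=LoadingOptions(check_nans=True))

The flagged method is shown in context below (the full torchio/data/image.py):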

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
            If a sequence of paths is given, data
            will be concatenated on the channel dimension so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        channels_last: If ``True``, the channels dimension of the read tensor
            is assumed to be the last one, and it will be moved to the first
            position.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = False,  # NaNs are removed by ITK by default
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.channels_last = channels_last

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self) -> torch.Tensor:
        return self[DATA]

    @property
    def tensor(self) -> torch.Tensor:
        return self.data

    @property
    def affine(self) -> np.ndarray:
        return self[AFFINE]

    @property
    def type(self) -> str:
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    def check_is_2d(self) -> None:
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self) -> Tuple[str, str, str]:
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self) -> Tuple[float, float, float]:
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> float:
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    @property
    def bounds(self) -> np.ndarray:
        ini = 0, 0, 0
        fin = np.array(self.spatial_shape) - 1
        point_ini = nib.affines.apply_affine(self.affine, ini)
        point_fin = nib.affines.apply_affine(self.affine, fin)
        return np.array((point_ini, point_fin))

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -3 + index
            return index

    # flake8: noqa: E701
    @staticmethod
    def flip_axis(axis):
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: "{path}"')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Optional[Union[Path, List[Path]]]:
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        if tensor.ndim != 4:
            raise ValueError('Input tensor must be 4D')
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The resulting 4D tensor of size :math:`(C, W, H, D)` and the
        :math:`4 \times 4` affine matrix to convert voxel indices to world
        coordinates are cached under the ``DATA`` and ``AFFINE`` keys.
        """
        if self._loaded:
            return
        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = self.read_and_check(paths[0])
        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = self.read_and_check(path)
            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)
            if not tensor.shape[1:] == new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match: found {tensor.shape} '
                    f'and {new_tensor.shape}'
                )
                raise RuntimeError(message)
            tensors.append(new_tensor)
        tensor = torch.cat(tensors)
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def read_and_check(self, path):
        tensor, affine = read_image(path)
        tensor = self.parse_tensor_shape(tensor)
        if self.channels_last:
            # Move the channels from the last to the first dimension
            tensor = tensor.permute(3, 0, 1, 2)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{path}"')
        return tensor, affine

    def save(self, path: TypePath, squeeze: bool = True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def plot(self, **kwargs) -> None:
        from ..visualization import plot_image  # avoid circular import
        plot_image(self, **kwargs)

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)


class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a value for :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap(                     # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... )

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a value for :py:attr:`type` other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
529