Passed
Push — master (53ab14...c2608f) by Fernando
created 01:07

torchio.data.image.Image.__init__()   (rating: C)

Complexity
    Conditions: 10

Size
    Total Lines: 48
    Code Lines: 37

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric  Value
cc      10
eloc    37
nop     9
dl      0
loc     48
rs      5.9999
c       0
b       0
f       0

How to fix

Complexity

Complex methods like torchio.data.image.Image.__init__() often do a lot of different things. To break the enclosing class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster. A sketch of Extract Class follows.
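For illustration, here is a minimal sketch of Extract Class applied to this case. The names (TensorParser, its methods) are invented for the example and are not part of TorchIO; the idea is that the parse_* methods of Image and the fields they share (check_nans, num_spatial_dims) form a cohesive component that could be pulled out of __init__:

import warnings
from typing import Optional

import numpy as np
import torch


class TensorParser:
    # Hypothetical extracted class: owns the parsing-related state that
    # Image.__init__ currently sets up itself.
    def __init__(self, check_nans: bool = True,
                 num_spatial_dims: Optional[int] = None):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

    def parse(self, tensor) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor


# Image.__init__ would then delegate instead of doing the work itself:
#     self._parser = TensorParser(check_nans, num_spatial_dims)
#     tensor = self._parser.parse(tensor)

This shrinks both the conditional logic and the state handled directly by __init__.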

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more, or different, data.

There are several standard approaches to avoid long parameter lists, e.g. from Fowler's refactoring catalog (a sketch of the last one follows this list):

- Replace Parameter with Query: let the method compute a value itself instead of receiving it.
- Preserve Whole Object: pass one object rather than several values extracted from it.
- Introduce Parameter Object: bundle parameters that belong together into a new class.
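Below is a minimal sketch of Introduce Parameter Object using invented names (ImageOptions and make_image are not part of TorchIO). The loading-related flags of Image.__init__ move into one object, so the nine parameters reported above shrink and new options no longer grow the signature:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ImageOptions:
    # Hypothetical parameter object bundling the loading-related flags.
    check_nans: bool = True
    num_spatial_dims: Optional[int] = None
    channels_last: Optional[bool] = None


def make_image(path=None, tensor=None, affine=None,
               options: Optional[ImageOptions] = None):
    options = options or ImageOptions()
    # The body reads options.check_nans etc. instead of separate arguments.
    return path, tensor, affine, options


make_image('t1.nii.gz', options=ImageOptions(check_nans=False))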

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims`
            and :attr:`channels_last`). If a sequence of paths is given, data
            will be concatenated on the channel dimension, so spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, H, W, D)`. If it is not 4D, TorchIO will try to guess
            the dimension meanings. If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            specified in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimension,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path`
            is given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and specifying it'
                ' will be mandatory in the future. You can probably use'
                ' ScalarImage or LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # The number of channels is typically stored in the last dimension
        # on disk, but if a tensor is given, the channels should be in the
        # first dimension
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    def check_is_2d(self):
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def width(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, H, W, D)
        if axis == 'H':
            return 1
        elif axis == 'W':
            return 2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -4 + index
            return index

    @staticmethod
    def flip_axis(axis):
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            message = (
                f'Axis not understood. Please use a value in {tuple("LRAPIS")}'
            )
            raise ValueError(message)

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with type'
                f' {type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def _parse_path(
            self,
            path: Optional[Union[TypePath, Sequence[TypePath]]],
            ) -> Optional[Union[Path, List[Path]]]:
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def parse_tensor(self, tensor: TypeData) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The data is stored in the image dictionary as a 4D tensor of size
        :math:`(C, H, W, D)`, together with a :math:`4 \times 4` affine matrix
        used to convert voxel indices to world coordinates.
        """
        if self._loaded:
            return

        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = read_image(paths[0])
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{paths[0]}"')

        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = read_image(path)
            new_tensor = self.parse_tensor_shape(new_tensor)

            if self.check_nans and torch.isnan(new_tensor).any():
                warnings.warn(f'NaNs found in file "{path}"')

            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)

            if not tensor.shape[1:] == new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match: found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)

            tensors.append(new_tensor)

        tensor = torch.cat(tensors)

        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)


class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a value for :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a value for :py:attr:`type` other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
543