Status: passed. Push to master (d76542...7bf0dc) by Fernando, created 01:06.

torchio.data.image.Image.__init__() (rating: C)

Complexity
Conditions: 10

Size
Total Lines: 48
Code Lines: 37

Duplication
Lines: 0
Ratio: 0 %

Importance
Changes: 0

Metric  Value
cc      10
eloc    37
nop     9
dl      0
loc     48
rs      5.9999
c       0
b       0
f       0
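To make the cc and Conditions figures concrete, the sketch below approximates McCabe cyclomatic complexity with Python's standard ast module. The counting rule is an assumption (start at 1, add 1 per if/elif, loop, conditional expression and except clause, and 1 per extra operand of each and/or), and Scrutinizer's exact rules may differ. Applied by hand to Image.__init__() in the listing below, this rule gives the reported 10: five if statements, one for loop, two conditional expressions and one extra and operand, plus one. The helper name cyclomatic_complexity is hypothetical.

```python
import ast
import textwrap


def cyclomatic_complexity(source: str) -> int:
    """Approximate McCabe complexity of the first function in ``source``.

    Hypothetical helper, not part of Scrutinizer or TorchIO. Assumed rule:
    start at 1, add 1 per if/elif, loop, conditional expression and except
    clause, and add 1 per extra operand of each boolean and/or.
    Comprehension loops are not counted.
    """
    tree = ast.parse(textwrap.dedent(source))
    function = next(
        node for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    )
    branch_nodes = (ast.If, ast.IfExp, ast.For, ast.AsyncFor, ast.While,
                    ast.ExceptHandler)
    complexity = 1
    for node in ast.walk(function):
        if isinstance(node, branch_nodes):
            # elif is parsed as a nested ast.If, so it is counted here too
            complexity += 1
        elif isinstance(node, ast.BoolOp):
            # `a and b and c` holds two decision points
            complexity += len(node.values) - 1
    return complexity


# A reduced version of the argument checks in Image.__init__()
example = '''
def init(path=None, tensor=None, channels_last=None):
    if path is None and tensor is None:
        raise ValueError('A value for path or tensor must be given')
    if channels_last is None:
        channels_last = path is not None
    return '' if path is None else str(path)
'''
print(cyclomatic_complexity(example))  # 5: three branches, one `and`, plus 1
```

Splitting validation out of __init__() lowers this number for the constructor because each moved branch is counted in the helper instead.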

How to fix: Complexity, Many Parameters

Complexity

Complex classes like torchio.data.image.Image, whose __init__() is flagged here, often do many different things. To break such a class down, identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share a prefix or suffix; in Image, for example, parse_tensor, parse_tensor_shape, parse_affine, _parse_path and _parse_single_path all normalize constructor inputs.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
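A minimal sketch of Extract Class applied to the parse_* group, assuming a hypothetical ImageInputParser that owns the check_nans setting those methods read. To stay self-contained it uses NumPy arrays instead of torch tensors, skips TorchIO's ensure_4d shape logic and does not check that paths exist; ImageInputParser and SketchImage are illustrative names, not part of TorchIO.

```python
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np


@dataclass
class ImageInputParser:
    """Hypothetical component extracted from Image: validates raw inputs."""
    check_nans: bool = True

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        # Same rules as Image.parse_affine(): identity by default, 4 x 4 only
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def parse_array(self, array) -> Optional[np.ndarray]:
        # Cast to float32 and warn about NaNs, like Image.parse_tensor()
        if array is None:
            return None
        array = np.asarray(array, dtype=np.float32)
        if self.check_nans and np.isnan(array).any():
            warnings.warn('NaNs found in tensor')
        return array

    @staticmethod
    def parse_path(path) -> Union[None, Path, List[Path]]:
        # Unlike Image._parse_path(), this sketch does not check existence
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return Path(path).expanduser()
        return [Path(p).expanduser() for p in path]


class SketchImage(dict):
    """Hypothetical Image-like class that delegates parsing to the helper."""

    def __init__(self, path=None, tensor=None, affine=None, check_nans=True):
        super().__init__()
        parser = ImageInputParser(check_nans=check_nans)
        self.path = parser.parse_path(path)
        self['data'] = parser.parse_array(tensor)
        self['affine'] = parser.parse_affine(affine)


image = SketchImage(tensor=np.zeros((1, 4, 4, 4)))
print(image['data'].dtype, image['affine'].shape)  # float32 (4, 4)
```

Because the helper carries its own check_nans, the constructor no longer has to set instance attributes before calling the parsers, an ordering constraint that Image.__init__() currently has to respect.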

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more or different data. Image.__init__() takes nine parameters, counting self and **kwargs.

There are several approaches to avoid long parameter lists:
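One widely used approach is the Introduce Parameter Object refactoring: arguments that travel together are grouped into a single value. The sketch below assumes a hypothetical frozen dataclass, LoadOptions, that bundles check_nans, num_spatial_dims and channels_last from Image.__init__(), and a hypothetical resolve_input function standing in for the constructor; neither exists in TorchIO.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class LoadOptions:
    """Hypothetical parameter object for three Image.__init__() arguments."""
    check_nans: bool = True
    num_spatial_dims: Optional[int] = None
    channels_last: Optional[bool] = None

    def resolve_channels_last(self, has_path: bool) -> bool:
        # Same default as Image.__init__(): data read from disk is assumed
        # channels-last, tensors given in memory channels-first
        if self.channels_last is None:
            return has_path
        return self.channels_last


DEFAULT_OPTIONS = LoadOptions()


def resolve_input(
        path: Optional[str] = None,
        options: LoadOptions = DEFAULT_OPTIONS,
        ) -> Tuple[bool, Optional[int], bool]:
    """Hypothetical constructor stand-in: two arguments instead of four."""
    channels_last = options.resolve_channels_last(path is not None)
    return channels_last, options.num_spatial_dims, options.check_nans


print(resolve_input('t1.nii.gz'))  # (True, None, True)
print(resolve_input(options=LoadOptions(num_spatial_dims=2, check_nans=False)))
# (False, 2, False)
```

Freezing the dataclass makes the shared default instance safe to reuse, and adding a new loading option later changes LoadOptions instead of every call signature.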

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`). If a sequence of paths is given, the data
            will be concatenated along the channel dimension, so the spatial
            dimensions must match.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` should be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to infer
            the meaning of each dimension. If 2D, the shape will be interpreted
            as :math:`(H, W)`. If 3D, the number of spatial dimensions should
            be specified with :attr:`num_spatial_dims`. If
            :attr:`num_spatial_dims` is not given and the size of the first or
            last dimension is 3, the tensor will be interpreted as a
            multichannel 2D image. Otherwise, it will be interpreted as a 3D
            image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> doubled = torchio.ScalarImage(tensor=times_two, affine=image.affine)
        >>> doubled.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Channels are typically stored in the last dimension on disk,
        # but if a tensor is given, the channels should be in the first one
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, D, H, W)
        if axis == 'H':
            return -2
        elif axis == 'W':
            return -1
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -4 + index
            return index

    @staticmethod
    def flip_axis(axis):
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            message = (
                f'Axis not understood. Please use a value in {tuple("LRAPIS")}'
            )
            raise ValueError(message)

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def _parse_path(
            self,
            path: Optional[Union[TypePath, Sequence[TypePath]]]
            ) -> Optional[Union[Path, List[Path]]]:
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk and cache it in the dictionary.

        The files in :attr:`path` are concatenated along the channel dimension
        into a 4D tensor of size :math:`(C, D, H, W)`, stored under ``DATA``.
        The :math:`4 \times 4` affine matrix of the first file, which maps
        voxel indices to world coordinates, is stored under ``AFFINE``.
        """
        if self._loaded:
            return

        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = read_image(paths[0])
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{paths[0]}"')

        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = read_image(path)
            new_tensor = self.parse_tensor_shape(new_tensor)

            # Check the newly read file, not the first one
            if self.check_nans and torch.isnan(new_tensor).any():
                warnings.warn(f'NaNs found in file "{path}"')

            if not np.array_equal(affine, new_affine):
                message = 'Files have different affine matrices'
                warnings.warn(message, RuntimeWarning)

            if not tensor.shape[1:] == new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match, found {tensor.shape}'
                    f' and {new_tensor.shape}'
                )
                raise RuntimeError(message)

            tensors.append(new_tensor)

        tensor = torch.cat(tensors)

        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)


class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> # Creating from a tensor
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> # Loading from a file without checking for NaNs
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
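The lazy loading described in the Image docstring works through __getitem__: the first request for DATA or AFFINE calls _load(), which stores both values in the dictionary so later requests are plain lookups. The pattern can be sketched in isolation as follows. The sketch assumes a hypothetical zero-argument loader callable in place of read_image(path) and uses plain string keys in place of TorchIO's DATA and AFFINE constants. LazyImage and fake_read are illustrative names only.

```python
from typing import Any, Callable, Tuple


class LazyImage(dict):
    """Sketch of the lazy-loading dictionary pattern used by Image."""

    LAZY_KEYS = ('data', 'affine')  # stand-ins for torchio.DATA, AFFINE

    def __init__(self, loader: Callable[[], Tuple[Any, Any]], **kwargs):
        super().__init__(**kwargs)
        self._loader = loader  # hypothetical stand-in for read_image(path)
        self._loaded = False

    def __getitem__(self, key):
        # Only the first request for a missing lazy key triggers a read
        if key in self.LAZY_KEYS and key not in self:
            self._load()
        return super().__getitem__(key)

    def _load(self) -> None:
        if self._loaded:
            return
        data, affine = self._loader()
        # Item assignment goes through dict.__setitem__, so no recursion
        self['data'] = data
        self['affine'] = affine
        self._loaded = True


calls = []


def fake_read():
    # Hypothetical loader that records how often it is called
    calls.append('read')
    return [[1.0, 2.0]], 'identity-affine'


image = LazyImage(fake_read, type='intensity')
print(len(calls))  # 0: creating the object reads nothing
first = image['data']
again = image['affine']
print(len(calls))  # 1: read once, then served from the dict
```

Dict methods such as get() and items() bypass __getitem__, so they do not trigger loading in this sketch. In the listing, __copy__() and crop() read self.data before iterating over items(), which forces the load first.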