Passed
Pull Request — master (#286)
by Fernando
01:09
created

torchio.data.image   F

Complexity

Total Complexity 90

Size/Duplication

Total Lines 542
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 296
dl 0
loc 542
rs 2
c 0
b 0
f 0
wmc 90

36 Methods

Rating   Name   Duplication   Size   Complexity  
A Image.parse_affine() 0 9 4
A Image.get_center() 0 15 2
A Image.set_check_nans() 0 2 1
A Image.numpy() 0 3 1
A Image.is_2d() 0 2 1
A Image.as_sitk() 0 3 1
A Image.spatial_shape() 0 3 1
A Image.__array__() 0 2 1
A Image._parse_single_path() 0 21 5
B Image.flip_axis() 0 12 7
B Image.parse_tensor() 0 12 7
A Image.get_bounds() 0 9 1
A Image.data() 0 3 1
A Image.save() 0 13 1
A LabelMap.__init__() 0 5 3
A Image.memory() 0 3 1
A Image.__repr__() 0 15 2
A Image.type() 0 3 1
A Image.__copy__() 0 11 3
A Image.crop() 0 17 3
A Image._parse_path() 0 10 3
C Image.__init__() 0 42 9
C Image.load() 0 53 10
A Image.check_is_2d() 0 4 2
A Image.shape() 0 3 1
A Image.__getitem__() 0 5 3
A Image.orientation() 0 3 1
A Image.width() 0 4 1
A Image.tensor() 0 3 1
A Image.parse_tensor_shape() 0 6 1
A Image.spacing() 0 4 1
A Image.get_spacing_string() 0 4 1
A ScalarImage.__init__() 0 5 3
A Image.axis_name_to_index() 0 36 4
A Image.height() 0 4 1
A Image.affine() 0 3 1
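The wmc value in the metrics table agrees with the Complexity column above. Assuming wmc is the weighted method count, it is the sum of the per-method complexities. Adding up the 36 rows gives 90, which is also the Total Complexity. The short check below copies the table into a dictionary (the name METHOD_COMPLEXITY is ours, not the report's) and lists the largest contributors. Those methods are the obvious places to start when following the refactoring advice.

```python
# Per-method complexities copied from the "36 Methods" table above.
# Assumption: wmc is the plain sum of these values (weighted method count).
METHOD_COMPLEXITY = {
    'Image.parse_affine()': 4,
    'Image.get_center()': 2,
    'Image.set_check_nans()': 1,
    'Image.numpy()': 1,
    'Image.is_2d()': 1,
    'Image.as_sitk()': 1,
    'Image.spatial_shape()': 1,
    'Image.__array__()': 1,
    'Image._parse_single_path()': 5,
    'Image.flip_axis()': 7,
    'Image.parse_tensor()': 7,
    'Image.get_bounds()': 1,
    'Image.data()': 1,
    'Image.save()': 1,
    'LabelMap.__init__()': 3,
    'Image.memory()': 1,
    'Image.__repr__()': 2,
    'Image.type()': 1,
    'Image.__copy__()': 3,
    'Image.crop()': 3,
    'Image._parse_path()': 3,
    'Image.__init__()': 9,
    'Image.load()': 10,
    'Image.check_is_2d()': 2,
    'Image.shape()': 1,
    'Image.__getitem__()': 3,
    'Image.orientation()': 1,
    'Image.width()': 1,
    'Image.tensor()': 1,
    'Image.parse_tensor_shape()': 1,
    'Image.spacing()': 1,
    'Image.get_spacing_string()': 1,
    'ScalarImage.__init__()': 3,
    'Image.axis_name_to_index()': 4,
    'Image.height()': 1,
    'Image.affine()': 1,
}

wmc = sum(METHOD_COMPLEXITY.values())

# Highest-complexity methods first; ties keep table order (sorted is stable)
hotspots = sorted(
    METHOD_COMPLEXITY.items(), key=lambda item: item[1], reverse=True
)[:3]

print(len(METHOD_COMPLEXITY), wmc)  # 36 90
for name, value in hotspots:
    print(f'{name}: {value}')
```

The top three are Image.load() (10), Image.__init__() (9) and Image.flip_axis() (7). Together they account for more than a quarter of the total.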

How to fix: Complexity

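Complex classes like Image in torchio.data.image often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach is to look for fields or methods that share the same prefixes or suffixes. In this file, for example, parse_tensor, parse_tensor_shape, parse_affine, _parse_path and _parse_single_path all share the parse prefix. Similarly, axis_name_to_index and flip_axis only deal with axis names and the image orientation.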

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
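In this file, Extract Class could start with the axis helpers, which only need the orientation codes. The sketch below is a hypothetical refactoring, not TorchIO code: AxisNames and the cut-down Image are illustrative names. The logic mirrors axis_name_to_index and flip_axis in the listing below, except that the if/elif chain of flip_axis is swapped for a lookup table. Image keeps its public method and delegates to the extracted class.

```python
from typing import Sequence, Tuple


class AxisNames:
    """Hypothetical class extracted from Image: axis-name helpers.

    It only needs the orientation codes, e.g. ('R', 'A', 'S'), which Image
    obtains from nib.aff2axcodes(self.affine).
    """

    # Lookup table replacing the if/elif chain of Image.flip_axis
    _OPPOSITE = {'R': 'L', 'L': 'R', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}

    def __init__(self, orientation: Sequence[str]):
        self.orientation: Tuple[str, ...] = tuple(orientation)

    @classmethod
    def flip_axis(cls, axis: str) -> str:
        if axis not in cls._OPPOSITE:
            values = ', '.join('LRPAISTB')
            raise ValueError(f'Axis not understood. Please use one of: {values}')
        return cls._OPPOSITE[axis]

    def axis_name_to_index(self, axis: str) -> int:
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()  # only the first letter is used
        if axis in 'TB':  # Top, Bottom: vertical axis of a (C, W, H, D) tensor
            return -2
        try:
            index = self.orientation.index(axis)
        except ValueError:
            index = self.orientation.index(self.flip_axis(axis))
        # Negative index, valid with or without the channel dimension
        return index - 3


class Image:
    """Heavily reduced stand-in for torchio's Image, for illustration only."""

    def __init__(self, orientation: Sequence[str]):
        self.axes = AxisNames(orientation)  # composition instead of helpers

    def axis_name_to_index(self, axis: str) -> int:
        # Delegating keeps the public method, so existing callers still work
        return self.axes.axis_name_to_index(axis)


image = Image(('R', 'A', 'S'))
print(image.axis_name_to_index('Left'))      # -3
print(image.axis_name_to_index('anterior'))  # -2
```

Composition is used here rather than Extract Subclass because the axis helpers are a capability every image needs, not a specialised kind of image.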

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, Union, Sequence, List

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file or sequence of paths to files that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims`).
            If a sequence of paths is given, the data will be concatenated
            along the channel dimension, so the spatial dimensions must match.
            If the read tensor is not 4D, TorchIO will try to guess the
            meaning of its dimensions. If 2D, the shape will be interpreted as
            :math:`(W, H)`. If 3D, the number of spatial dimensions should be
            given in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims` is
            not given and the size of the first or last dimension is 3, it
            will be interpreted as a multichannel 2D image. Otherwise, it will
            be interpreted as a 3D image with a single channel.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, W, H, D)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # subclass of Image
        >>> image  # not loaded yet
        ScalarImage(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        ScalarImage(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Union[TypePath, Sequence[TypePath], None] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead.'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)

        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    def check_is_2d(self):
        if not self.is_2d():
            message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
            raise RuntimeError(message)

    @property
    def height(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[1]

    @property
    def width(self) -> int:
        self.check_is_2d()
        return self.spatial_shape[0]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.

        .. note:: If your images are 2D, you can use ``'Top'``, ``'Bottom'``,
            ``'Left'`` and ``'Right'``.
        """
        # Top and bottom are used for the vertical 2D axis as the use of
        # Height vs Horizontal might be ambiguous

        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, W, H, D)
        if axis in 'TB':  # Top, Bottom
            return -2
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not
            index = -3 + index
            return index

    @staticmethod
    def flip_axis(axis):
        opposites = {
            'R': 'L', 'L': 'R',
            'A': 'P', 'P': 'A',
            'I': 'S', 'S': 'I',
        }
        if axis not in opposites:
            values = ', '.join('LRPAISTB')
            message = f'Axis not understood. Please use one of: {values}'
            raise ValueError(message)
        return opposites[axis]

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm.

        The bounds are measured at the outer edges of the first and last
        voxels along each axis, i.e. at voxel indices -0.5 and size - 0.5.
        One pair of world coordinates is returned for each of the three axes.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_single_path(
            path: TypePath
            ) -> Path:
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = (
                f'Expected type str or Path but found {path} with '
                f'{type(path)} instead'
            )
            raise TypeError(message)
        except RuntimeError:
            message = (
                f'Conversion to path not possible for variable: {path}'
            )
            raise RuntimeError(message)

        if not (path.is_file() or path.is_dir()):   # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def _parse_path(
            self,
            path: Union[TypePath, Sequence[TypePath], None]
            ) -> Union[Path, List[Path], None]:
        if path is None:
            return None
        if isinstance(path, (str, Path)):
            return self._parse_single_path(path)
        else:
            return [self._parse_single_path(p) for p in path]

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        else:
            message = (
                'Input tensor must be a NumPy array or a PyTorch tensor,'
                f' not {type(tensor)}'
            )
            raise TypeError(message)
        if tensor.ndim != 4:
            message = (
                'Input tensor must be 4D, but its shape is'
                f' {tuple(tensor.shape)}'
            )
            raise ValueError(message)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(
            self,
            tensor: torch.Tensor,
            num_spatial_dims: int,
            ) -> torch.Tensor:
        return ensure_4d(tensor, num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The data are stored in ``self[DATA]`` as a 4D tensor of size
        :math:`(C, W, H, D)`, and the :math:`4 \times 4` affine matrix that
        converts voxel indices to world coordinates is stored in
        ``self[AFFINE]``. If several paths were given, the images are
        concatenated along the channel dimension. Calling this method on an
        image that is already loaded does nothing.
        """
        if self._loaded:
            return

        paths = self.path if isinstance(self.path, list) else [self.path]
        tensor, affine = read_image(paths[0])
        tensor = self.parse_tensor_shape(tensor, self.num_spatial_dims)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{paths[0]}"')

        tensors = [tensor]
        for path in paths[1:]:
            new_tensor, new_affine = read_image(path)
            new_tensor = self.parse_tensor_shape(
                new_tensor,
                self.num_spatial_dims,
            )

            # Check the tensor that was just read, not the first one
            if self.check_nans and torch.isnan(new_tensor).any():
                warnings.warn(f'NaNs found in file "{path}"')

            if not np.array_equal(affine, new_affine):
                message = (
                    'Files have different affine matrices.'
                    f'\nMatrix of {paths[0]}:'
                    f'\n{affine}'
                    f'\nMatrix of {path}:'
                    f'\n{new_affine}'
                )
                warnings.warn(message, RuntimeWarning)

            if tensor.shape[1:] != new_tensor.shape[1:]:
                message = (
                    f'File shapes do not match, found {tuple(tensor.shape)}'
                    f' and {tuple(new_tensor.shape)}'
                )
                # The exception must be raised, not only instantiated
                raise RuntimeError(message)

            tensors.append(new_tensor)

        tensor = torch.cat(tensors)

        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
        )

    def is_2d(self) -> bool:
        return self.shape[-1] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self, **kwargs) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)


class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.ScalarImage('t1.nii.gz')
        >>> # Creating from a tensor
        >>> dmri = torchio.ScalarImage(tensor=torch.rand(32, 128, 128, 88))
        >>> image = torchio.ScalarImage('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a value other than :py:attr:`torchio.INTENSITY` is
            given for :py:attr:`type`.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(1, 128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file
        >>> tpm = torchio.LabelMap([                    # loading from files
        ...     'gray_matter.nii.gz',
        ...     'white_matter.nii.gz',
        ...     'csf.nii.gz',
        ... ])

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a value other than :py:attr:`torchio.LABEL` is given
            for :py:attr:`type`.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
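In the listing, get_center() and get_bounds() both apply the affine to a voxel index. They differ only in which index they use: get_center() uses the center of the grid, (size - 1) / 2, while get_bounds() uses the outer voxel faces at -0.5 and size - 0.5. The dependency-free sketch below makes that explicit. Here apply_affine is a hand-written stand-in for nib.affines.apply_affine. The example affine (2 mm voxels with an arbitrary origin) is made up for illustration.

```python
from typing import List, Sequence, Tuple

# A 4 x 4 affine given as nested lists (rows)
Matrix = Sequence[Sequence[float]]


def apply_affine(affine: Matrix, index: Sequence[float]) -> Tuple[float, ...]:
    """Map a voxel index (i, j, k) to world coordinates (x, y, z) in mm."""
    i, j, k = index
    # Only the top three rows matter; the last row is (0, 0, 0, 1)
    return tuple(
        row[0] * i + row[1] * j + row[2] * k + row[3]
        for row in affine[:3]
    )


def get_center(affine: Matrix, spatial_shape: Sequence[int], lps: bool = False):
    # Center of the voxel grid in index space: (size - 1) / 2 per axis
    center_index = [(n - 1) / 2 for n in spatial_shape]
    r, a, s = apply_affine(affine, center_index)
    # RAS+ to LPS+ only flips the sign of the first two world axes
    return (-r, -a, s) if lps else (r, a, s)


def get_bounds(affine: Matrix, spatial_shape: Sequence[int]) -> Tuple[List[float], ...]:
    # Voxel centers sit at integer indices, so the outer faces of the grid
    # are at index -0.5 and size - 0.5 along each axis
    first_point = apply_affine(affine, (-0.5, -0.5, -0.5))
    last_point = apply_affine(affine, [n - 0.5 for n in spatial_shape])
    # One [first, last] pair per world axis, like array.T.tolist()
    return tuple([first, last] for first, last in zip(first_point, last_point))


# Example: 2 mm isotropic voxels, 10 x 20 x 30 grid, origin at (-10, -20, -30)
AFFINE = [
    [2.0, 0.0, 0.0, -10.0],
    [0.0, 2.0, 0.0, -20.0],
    [0.0, 0.0, 2.0, -30.0],
    [0.0, 0.0, 0.0, 1.0],
]
SHAPE = (10, 20, 30)

print(get_center(AFFINE, SHAPE))            # (-1.0, -1.0, -1.0)
print(get_center(AFFINE, SHAPE, lps=True))  # (1.0, 1.0, -1.0)
print(get_bounds(AFFINE, SHAPE))            # ([-11.0, 9.0], [-21.0, 19.0], [-31.0, 29.0])
```

The bounds are one voxel (2 mm) wider than the span of voxel centers, because each side gains half a voxel.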
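The Image docstring calls TorchIO images lazy loaders. In the listing, __getitem__ calls load() only when DATA or AFFINE is missing, and load() returns early once the image is loaded. Below is the same pattern stripped down to a dict subclass. LazyDict, its loader argument, load_count and fake_read are hypothetical names for illustration. fake_read stands in for read_image, and a key-presence check replaces the _loaded flag used in the listing.

```python
from typing import Any, Callable, Tuple


class LazyDict(dict):
    """Hypothetical minimal version of the lazy-loading pattern in Image.

    'data' and 'affine' are only produced by the loader the first time one
    of them is requested; afterwards they are cached in the dict itself.
    """

    LAZY_KEYS = ('data', 'affine')

    def __init__(self, loader: Callable[[], Tuple[Any, Any]], **kwargs: Any):
        super().__init__(**kwargs)
        self._loader = loader
        self.load_count = 0  # how many times the loader actually ran

    def load(self) -> None:
        if 'data' in self:  # already loaded: nothing to do
            return
        data, affine = self._loader()
        self['data'] = data
        self['affine'] = affine
        self.load_count += 1

    def __getitem__(self, key):
        # Only the lazy keys trigger a load; other items behave as in a dict
        if key in self.LAZY_KEYS and key not in self:
            self.load()
        return super().__getitem__(key)


def fake_read() -> Tuple[Any, Any]:
    # Stand-in for read_image(path); returns (data, affine) placeholders
    return [[1.0, 2.0]], 'identity'


image = LazyDict(fake_read, type='intensity')
print('data' in image, image.load_count)  # False 0
print(image['data'])                      # [[1.0, 2.0]]
print(image['affine'], image.load_count)  # identity 1
```

Membership tests such as `'data' in image` go through dict.__contains__ and never trigger a load. Only item access does.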