Passed · Push to master (0a9301...f0d368) by Fernando, 01:26 · created

torchio.data.image.Image.__init__() (rating: C)

Complexity
    Conditions: 10

Size
    Total Lines: 48
    Code Lines: 37

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
cc        10
eloc      37
nop       9
dl        0
loc       48
rs        5.9999
c         0
b         0
f         0

How to fix: Complexity, Many Parameters

Complexity

A complex method like torchio.data.image.Image.__init__() often indicates that its class does a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields or methods that share the same prefix or suffix.
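That heuristic can be sketched in a few lines of Python. The helper below is hypothetical (it is not part of Scrutinizer or TorchIO): it groups identifiers by their first word, ignoring leading underscores and dunder methods, and it is run here on the method names of the `Image` class in the listing below.

```python
from collections import defaultdict
from typing import Dict, Iterable, List


def group_by_prefix(names: Iterable[str], min_size: int = 2) -> Dict[str, List[str]]:
    """Group identifiers by their first word, ignoring leading underscores.

    Hypothetical helper: only groups with at least `min_size` members are
    kept, since a single name does not suggest a cohesive component.
    """
    groups: Dict[str, List[str]] = defaultdict(list)
    for name in names:
        if name.startswith('__'):  # skip dunder methods such as __init__
            continue
        prefix = name.lstrip('_').split('_')[0]
        groups[prefix].append(name)
    return {prefix: members for prefix, members in groups.items()
            if len(members) >= min_size}


# Method names copied from the Image class in the listing
image_methods = [
    '__init__', '__repr__', 'axis_name_to_index', 'flip_axis',
    'get_spacing_string', 'get_bounds', '_parse_path', 'parse_tensor',
    'parse_tensor_shape', 'parse_affine', '_load', 'save', 'is_2d',
    'numpy', 'as_sitk', 'get_center', 'set_check_nans', 'crop',
]
print(group_by_prefix(image_methods))
```

For `Image`, the `parse` group is the stronger candidate: the `check_nans`, `num_spatial_dims` and `channels_last` attributes are read only by `parse_tensor`, `parse_tensor_shape` and `_load`. The `get` methods share a verb but not their data.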

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
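Following that heuristic, the three parsing settings and the `parse_tensor*` logic could become their own class. The sketch below is hypothetical: `TensorParsingOptions` is not a TorchIO name, and, to stay runnable without PyTorch, it only reproduces the `channels_last` default computed in `Image.__init__` and a NaN check on a flat sequence of floats in place of a tensor.

```python
import math
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class TensorParsingOptions:
    """Hypothetical component extracted from Image: the settings that only
    the tensor-parsing code reads."""
    check_nans: bool = True
    num_spatial_dims: Optional[int] = None
    channels_last: Optional[bool] = None

    def resolve_channels_last(self, has_path: bool) -> bool:
        # Same default as Image.__init__: channels last for data read from
        # disk, channels first for tensors passed in memory
        if self.channels_last is None:
            return has_path
        return self.channels_last

    def warn_if_nans(self, values: Iterable[float], source: str) -> bool:
        # Stand-in for torch.isnan(tensor).any() on a flat sequence of floats
        if not self.check_nans:
            return False
        found = any(math.isnan(v) for v in values)
        if found:
            warnings.warn(f'NaNs found in {source}')
        return found


options = TensorParsingOptions(num_spatial_dims=2)
print(options.resolve_channels_last(has_path=True))   # True
print(options.resolve_channels_last(has_path=False))  # False
```

If callers passed one such object instead of the three flags, the constructor would also go from nine parameters (`nop 9`) to seven.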

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:
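One approach that applies directly to this code is to build the argument list in a single place. In the listing below, `__copy__` and `crop` each rebuild the same constructor arguments from `self` and differ only in `tensor` and `affine`. A minimal sketch, assuming a stripped-down, hypothetical `SketchImage` dict subclass (with string stand-ins for the TorchIO key constants) so that it runs without TorchIO, moves that list into one helper; `_constructor_kwargs` is an illustrative name, not TorchIO API, and the sketch's `crop` only overrides `tensor`.

```python
import copy
from typing import Any, Dict, Optional

# Hypothetical stand-ins for the torchio key constants
DATA, AFFINE, TYPE, PATH, STEM = 'data', 'affine', 'type', 'path', 'stem'
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class SketchImage(dict):
    """Hypothetical, stripped-down stand-in for torchio's Image."""

    def __init__(self, tensor: Any = None, affine: Any = None,
                 type: str = 'intensity', path: Optional[str] = None,
                 channels_last: bool = False, **kwargs: Any):
        for key in PROTECTED_KEYS:
            if key in kwargs:
                raise ValueError(f'Key "{key}" is reserved. Use a different one')
        super().__init__(**kwargs)
        self.channels_last = channels_last
        self[DATA], self[AFFINE], self[TYPE] = tensor, affine, type
        self[PATH] = '' if path is None else str(path)

    def _constructor_kwargs(self, **overrides: Any) -> Dict[str, Any]:
        # Single place that spells out how to rebuild an instance from self
        kwargs: Dict[str, Any] = dict(
            tensor=self[DATA],
            affine=self[AFFINE],
            type=self[TYPE],
            path=self[PATH] or None,
            channels_last=False,
        )
        # Carry over the extra (non-reserved) items, e.g. acquisition parameters
        kwargs.update(
            (key, value) for key, value in self.items()
            if key not in PROTECTED_KEYS
        )
        kwargs.update(overrides)
        return kwargs

    def __copy__(self) -> 'SketchImage':
        return self.__class__(**self._constructor_kwargs())

    def crop(self, start: int, stop: int) -> 'SketchImage':
        # A 1D list slice stands in for the 3D tensor crop of the listing
        patch = self[DATA][start:stop]
        return self.__class__(**self._constructor_kwargs(tensor=patch))


image = SketchImage(tensor=[1.0, 2.0, 3.0, 4.0], age=42)
patch = image.crop(1, 3)
print(patch[DATA], patch['age'])  # [2.0, 3.0] 42
print(copy.copy(image) == image)  # True
```

Adding a constructor parameter then means touching `_constructor_kwargs` once, instead of every method that re-creates an image.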

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be given.
            It is expected to be a 4D :py:class:`torch.Tensor` or NumPy array
            with dimensions :math:`(C, D, H, W)`. If it is not 4D, TorchIO will
            try to guess the meaning of the dimensions. If 2D, the shape will be
            interpreted as :math:`(H, W)`. If 3D, the number of spatial
            dimensions should be specified with :attr:`num_spatial_dims`. If
            :attr:`num_spatial_dims` is not given and the size of the first or
            last dimension is 3, the tensor will be interpreted as a
            multichannel 2D image. Otherwise, it will be interpreted as a 3D
            image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, the identity matrix is
            used.
        check_nans: If ``True``, a warning is issued if NaNs are found in the
            image. If ``False``, the image is not checked for NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters. The keys in ``PROTECTED_KEYS`` are reserved
            and raise a :class:`ValueError`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Creating from a tensor
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5), type=torchio.INTENSITY)
        >>> image = torchio.Image('safe_image.nrrd', type=torchio.INTENSITY, check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image[torchio.DATA] = times_two
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Any,
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead'
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False

        # Channels are typically stored in the last dimension on disk,
        # but if a tensor is given, the channels should be in the first one
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        kwargs = dict(
            tensor=self.data,
            affine=self.affine,
            type=self.type,
            path=self.path,
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def axis_name_to_index(self, axis: str):
        """Convert an axis name to an axis index.

        Args:
            axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
                versions and first letters are also valid, as only the first
                letter will be used.

        Returns:
            A negative index, so that the result does not depend on whether
            the channels dimension is included.

        .. note:: If you are working with animals, you should probably use
            ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
            for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
            respectively.
        """
        if not isinstance(axis, str):
            raise ValueError('Axis must be a string')
        axis = axis[0].upper()

        # Generally, TorchIO tensors are (C, D, H, W)
        if axis == 'H':
            return -2
        elif axis == 'W':
            return -1
        else:
            try:
                index = self.orientation.index(axis)
            except ValueError:
                index = self.orientation.index(self.flip_axis(axis))
            # Return negative indices so that it does not matter whether we
            # refer to spatial dimensions or not. The spatial dimensions are
            # the last three, so spatial index 0 maps to -3
            index = -3 + index
            return index

    @staticmethod
    def flip_axis(axis):
        if axis == 'R': return 'L'
        elif axis == 'L': return 'R'
        elif axis == 'A': return 'P'
        elif axis == 'P': return 'A'
        elif axis == 'I': return 'S'
        elif axis == 'S': return 'I'
        else:
            message = (
                f'Axis not understood. Please use a value in {tuple("LRAPIS")}'
            )
            raise ValueError(message)

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm."""
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message)
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor.astype(np.float32))
        elif isinstance(tensor, torch.Tensor):
            tensor = tensor.float()
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk.

        The data are stored in the image dictionary as a 4D tensor of size
        :math:`(C, D, H, W)` and a :math:`4 \times 4` affine matrix that
        converts voxel indices to world coordinates.
        """
        if self._loaded:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS: continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)


class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.INTENSITY`
            is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL` is
            given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
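The index arithmetic in `axis_name_to_index` can be checked without loading an image. The functions below are a hypothetical, standalone restatement: they assume the orientation codes (what `nib.aff2axcodes` returns for the affine) are passed in as a tuple, and they use a dictionary lookup in place of the `if`/`elif` chain of `flip_axis`, which removes conditions from that method. Because the three spatial dimensions are the last three of a (C, D, H, W) tensor, spatial index `i` maps to `i - 3`; with `i - 4`, the first spatial axis would point at the channels dimension.

```python
from typing import Sequence

# Opposite anatomical directions, replacing the if/elif chain of flip_axis
OPPOSITE = {'R': 'L', 'L': 'R', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}


def flip_axis(axis: str) -> str:
    try:
        return OPPOSITE[axis]
    except KeyError:
        raise ValueError(
            f'Axis not understood. Please use a value in {tuple("LRAPIS")}'
        ) from None


def axis_name_to_index(axis: str, orientation: Sequence[str]) -> int:
    """Map an axis name to a negative index into a (C, D, H, W) tensor.

    `orientation` holds one code per spatial dimension, e.g. ('R', 'A', 'S').
    """
    if not isinstance(axis, str):
        raise ValueError('Axis must be a string')
    axis = axis[0].upper()
    if axis == 'H':
        return -2
    if axis == 'W':
        return -1
    codes = tuple(orientation)
    index = codes.index(axis) if axis in codes else codes.index(flip_axis(axis))
    # Spatial dimensions are the last three, so spatial index 0 maps to -3
    return index - 3


orientation = ('L', 'P', 'S')
print([axis_name_to_index(name, orientation) for name in ('Right', 'anterior', 'S')])
# [-3, -2, -1]
```

'Right' is not in the orientation, so its opposite 'L' is looked up instead; both name the same axis, only the direction differs.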