Pull Request: master (#246) by Fernando, created 01:30 (status: Passed)

torchio.data.image.Image.__init__() (rating: C)

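Complexity
  Conditions: 10

Size
  Total Lines: 41
  Code Lines: 35

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
cc      10 (cyclomatic complexity)
eloc    35 (code lines)
nop     9 (number of parameters, including self and **kwargs)
dl      0 (duplicated lines)
loc     41 (total lines)
rs      5.9999
c       0
b       0
f       0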
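You can check the cc value of 10 by hand. Cyclomatic complexity is one plus the number of decision points. In `Image.__init__` the decision points are:

- `if path is None and tensor is None`, which counts twice (once for the `if` and once for the `and`)
- `if channels_last is None`
- `if tensor is not None`
- `if affine is None`
- the `for` loop over `PROTECTED_KEYS`
- `if key in kwargs`
- the two conditional expressions that set `PATH` and `STEM`

That gives 9 decision points, so the complexity is 1 + 9 = 10.

The sketch below is a simplified counter built on Python's `ast` module. Its counting rules are an assumption and may not match the exact rules Scrutinizer applies. `SOURCE` is a hypothetical stub with the same branching structure as `__init__`, not the real method.

```python
import ast

# Hypothetical stub with the same branches as Image.__init__
SOURCE = """
def init(path=None, tensor=None, affine=None, channels_last=None, **kwargs):
    if path is None and tensor is None:
        raise ValueError('A value for path or tensor must be given')
    if channels_last is None:
        channels_last = path is not None
    if tensor is not None:
        if affine is None:
            affine = 'identity'
    for key in ('data', 'affine', 'type', 'path', 'stem'):
        if key in kwargs:
            raise ValueError(key)
    path_string = '' if path is None else str(path)
    stem = '' if path is None else str(path)
    return channels_last, affine, path_string, stem
"""

# Simplified rule set (an assumption): each of these nodes adds one path
BRANCH_NODES = (ast.If, ast.IfExp, ast.For, ast.AsyncFor, ast.While,
                ast.ExceptHandler)


def cyclomatic_complexity(function_node: ast.AST) -> int:
    complexity = 1  # straight-line code has exactly one path
    for node in ast.walk(function_node):
        if isinstance(node, BRANCH_NODES):
            complexity += 1  # `elif` is a nested ast.If, so it is counted too
        elif isinstance(node, ast.BoolOp):
            # `a and b and c` adds one path per extra operand
            complexity += len(node.values) - 1
        elif isinstance(node, ast.comprehension):
            complexity += 1 + len(node.ifs)
    return complexity


function = ast.parse(SOURCE).body[0]
print(function.name, cyclomatic_complexity(function))  # init 10
```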

How to fix

Complexity

Complex classes and methods like torchio.data.image.Image.__init__() often do many different things. To break such code down, you need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
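In `Image`, one candidate is the shape-handling component. The `channels_last` and `num_spatial_dims` fields and the `parse_tensor_shape()` method exist only to decide how the dimensions of an input map onto the (C, D, H, W) layout.

The sketch below extracts them into a hypothetical `TensorLayout` class. It makes several simplifying assumptions:

- It computes only the target shape. A real implementation would also permute or reshape the tensor.
- The rules follow the `tensor` and `num_spatial_dims` descriptions in the `Image` docstring.
- 2D inputs get the layout (C, 1, H, W). This is inferred from `is_2d()`, which checks `shape[-3] == 1`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorLayout:
    """Hypothetical class extracted from Image: decides how input
    dimensions map onto the (C, D, H, W) convention."""

    channels_last: Optional[bool] = None
    num_spatial_dims: Optional[int] = None

    def resolve_defaults(self, path_given: bool) -> 'TensorLayout':
        # Same default as Image.__init__: channels last on disk, first in tensors
        if self.channels_last is not None:
            return self
        return TensorLayout(path_given, self.num_spatial_dims)

    def to_4d_shape(self, shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
        if len(shape) == 4:
            if self.channels_last:
                d, h, w, c = shape
            else:
                c, d, h, w = shape
            return (c, d, h, w)
        if len(shape) == 2:
            h, w = shape
            return (1, 1, h, w)  # single channel, singleton depth (assumed)
        if len(shape) == 3:
            spatial = self.num_spatial_dims
            if spatial is None:
                # Docstring rule: size 3 on the first or last axis means 2D
                spatial = 2 if 3 in (shape[0], shape[-1]) else 3
            if spatial == 2:
                if self.channels_last:
                    h, w, c = shape
                else:
                    c, h, w = shape
                return (c, 1, h, w)
            d, h, w = shape
            return (1, d, h, w)
        raise ValueError(f'Expected 2, 3 or 4 dimensions, got {len(shape)}')


layout = TensorLayout().resolve_defaults(path_given=False)
print(layout.to_4d_shape((3, 4, 5)))  # (3, 1, 4, 5): 2D image with 3 channels
print(TensorLayout(num_spatial_dims=3).to_4d_shape((3, 4, 5)))  # (1, 3, 4, 5)
```

With such a class, `Image.__init__` could accept one `layout` argument instead of two. The `channels_last` default, and its branch, would also move out of the constructor.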

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:
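One common approach is Introduce Parameter Object: arguments that always travel together are grouped into one object, which can also own their validation.

For `Image.__init__`, `path`, `tensor` and `affine` all describe where the voxels come from. The constructor also validates them inline with `if path is None and tensor is None`.

The sketch below moves them into a hypothetical `ImageSource` dataclass. `SlimImage` is a hypothetical stand-in for `Image`, reduced to a few fields, and the string `'intensity'` stands in for `torchio.INTENSITY`.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional


@dataclass(frozen=True)
class ImageSource:
    """Hypothetical parameter object: where the voxel data comes from."""

    path: Optional[Path] = None
    tensor: Optional[Any] = None
    affine: Optional[Any] = None

    def __post_init__(self) -> None:
        # The validation that Image.__init__ currently performs inline
        if self.path is None and self.tensor is None:
            raise ValueError('A value for path or tensor must be given')

    @property
    def is_lazy(self) -> bool:
        # Data must be read from disk later only when no tensor was given
        return self.tensor is None


class SlimImage(dict):
    """Hypothetical Image variant whose constructor takes the parameter object."""

    def __init__(self, source: ImageSource, type: str = 'intensity',
                 check_nans: bool = True, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.source = source
        self.check_nans = check_nans
        self['type'] = type
        self['path'] = '' if source.path is None else str(source.path)


image = SlimImage(ImageSource(path=Path('t1.nii.gz')), type='label')
print(image['path'], image['type'], image.source.is_lazy)  # t1.nii.gz label True
```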

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
            the meaning of each dimension. If 2D, the shape will be interpreted
            as :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            specified with :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the first or last dimension has size 3, the tensor
            will be interpreted as a multichannel 2D image. Otherwise, it will
            be interpreted as a 3D image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, a warning is issued if NaNs are found when
            the image is read from disk. If ``False``, images will not be
            checked for the presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False
        self.num_spatial_dims = num_spatial_dims

        # Channels are typically stored in the last dimension on disk,
        # but if a tensor is given, the channels should be in the first one
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            if affine is None:
                affine = np.eye(4)
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Data and affine are read from disk lazily, on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return np.prod(self.shape) * 4  # assumes float32, i.e. 4 bytes per voxel

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message)
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk and cache it in the dictionary.

        The data are stored under the ``DATA`` key as a 4D tensor of size
        :math:`(C, D, H, W)`, and the :math:`4 \times 4` affine matrix that
        converts voxel indices to world coordinates is stored under the
        ``AFFINE`` key.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first axis increases towards the left and the second
                towards posterior. Otherwise, the coordinates will be in RAS+
                orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans):
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image with the voxels between ``index_ini``
        (inclusive) and ``index_fin`` (exclusive), keeping all channels."""
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Keep every channel so that the patch is already (C, D, H, W)
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            # path is set, so channels_last would otherwise default to True
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)


class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a value other than :py:attr:`torchio.INTENSITY` is
            given for :py:attr:`type`.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a value other than :py:attr:`torchio.LABEL` is given
            for :py:attr:`type`.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
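The `Image` docstring describes images as lazy loaders. In the listing, `__getitem__` calls `_load()` only when `DATA` or `AFFINE` is requested and not yet present.

The stdlib-only sketch below isolates that pattern so it can run without image files. `LazyImage` and `fake_reader` are hypothetical. `fake_reader` stands in for `read_image`, and the two string keys stand in for the `torchio.DATA` and `torchio.AFFINE` constants. Unlike `Image`, the sketch tracks loading with a `_loaded` flag rather than checking key membership.

```python
from typing import Any, Callable, Tuple

# Stand-ins for the torchio.DATA and torchio.AFFINE key constants
DATA, AFFINE = 'data', 'affine'


class LazyImage(dict):
    """Minimal sketch of the lazy-loading pattern used by Image."""

    def __init__(self, path: str, reader: Callable[[str], Tuple[Any, Any]]) -> None:
        super().__init__()
        self.path = path
        self._reader = reader  # hypothetical stand-in for read_image
        self._loaded = False

    def __getitem__(self, item: str) -> Any:
        # Only the heavy keys trigger a read, and only the first time
        if item in (DATA, AFFINE) and not self._loaded:
            self._load()
        return super().__getitem__(item)

    def _load(self) -> None:
        data, affine = self._reader(self.path)
        self[DATA] = data  # dict.__setitem__, so no recursion into __getitem__
        self[AFFINE] = affine
        self._loaded = True


calls = []


def fake_reader(path: str) -> Tuple[Any, Any]:
    calls.append(path)  # record every simulated disk access
    return [[0.0, 1.0]], 'identity'


image = LazyImage('t1.nii.gz', fake_reader)
print(len(calls))  # 0: nothing has been read yet
data, affine = image[DATA], image[AFFINE]
print(len(calls))  # 1: read once, then served from the dictionary
```

One caveat of subclassing `dict` this way: `get()`, `values()` and `items()` do not go through the overridden `__getitem__`. As a result, `image.get(DATA)` returns `None` until the data has been loaded by an indexed access.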