Passed
Pull Request — master (#248)
by Fernando
01:11
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 40
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 9
dl 0
loc 40
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more or different data.

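There are several approaches to avoid long parameter lists, for example grouping related parameters into a single object (such as a configuration class), or dropping parameters whose values the method can obtain or compute by itself.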

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys managed by Image itself: they may not be passed as extra
# keyword arguments, and they are skipped when an image is copied or cropped
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
            the dimensions meanings. If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimensions,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs. The setting is inherited by copies and crops.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, or if
            :attr:`kwargs` contains a reserved key.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        # Validate user items before doing any (possibly expensive) parsing
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims
        self._loaded = False

        # Number of channels are typically stored in the last dimensions in disk
        # But if a tensor is given, the channels should be in the first dim
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            # Data given in memory: the image counts as already loaded
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        # Data and affine are loaded lazily on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self):
        return self[DATA].numpy()

    def __copy__(self):
        """Return a new image of the same class sharing data and affine.

        The tensor and affine objects are shared, not cloned. Extra items
        are copied shallowly. :attr:`check_nans` is preserved.
        """
        return self._new_with_data(self.data, self.affine)

    def _new_with_data(self, tensor: torch.Tensor, affine: np.ndarray):
        """Build an image of the same class with new data and this metadata."""
        kwargs = dict(
            tensor=tensor,
            affine=affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            # The given tensor is already 4D with channels first
            channels_last=False,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are shared with self
        return self.__class__(**kwargs)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes occupied by the image data tensor.

        Uses the actual element size of the tensor's dtype, so it is also
        correct for non-float32 data such as boolean or integer label maps.
        """
        data = self.data
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`pathlib.Path`, or ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message)
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D :class:`torch.Tensor`.

        NumPy arrays are converted with :func:`torch.from_numpy`, which shares
        memory. Returns ``None`` if ``tensor`` is ``None``.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, or return the identity if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk and cache its data and affine.

        The tensor read from :attr:`path` is made 4D with :func:`ensure_4d`.
        It is checked for NaNs if :attr:`check_nans` is ``True``. It is stored
        under :attr:`torchio.DATA`, and the :math:`4 \times 4` affine is
        stored under :attr:`torchio.AFFINE`. Does nothing if the image is
        already loaded.
        """
        if self._loaded:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.

        Returns:
            World coordinates of the (possibly fractional) central voxel index.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image containing a box of voxels from this one.

        Args:
            index_ini: Start voxel index (inclusive) along each spatial axis.
            index_fin: End voxel index (exclusive) along each spatial axis.

        The new affine is this affine with its translation moved to the world
        position of ``index_ini``. The cropped data is cloned, and
        :attr:`check_nans` is preserved.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        return self._new_with_data(patch, new_affine)
367
368
369
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given, either as a keyword or as
            the second positional argument.
    """
    def __init__(self, *args, **kwargs):
        # ``type`` is the second positional parameter of Image.__init__, so
        # it may arrive positionally as well as by keyword
        given_types = list(args[1:2])
        if 'type' in kwargs:
            given_types.append(kwargs['type'])
        if any(given != INTENSITY for given in given_types):
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        if len(args) < 2:
            # Only inject the keyword when it was not passed positionally,
            # otherwise Image.__init__ would receive ``type`` twice
            kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
397
398
399
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given, either as a keyword or as the second positional
            argument.
    """
    def __init__(self, *args, **kwargs):
        # ``type`` is the second positional parameter of Image.__init__, so
        # it may arrive positionally as well as by keyword
        given_types = list(args[1:2])
        if 'type' in kwargs:
            given_types.append(kwargs['type'])
        if any(given != LABEL for given in given_types):
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        if len(args) < 2:
            # Only inject the keyword when it was not passed positionally,
            # otherwise Image.__init__ would receive ``type`` twice
            kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
418