Passed
Pull Request — master (#246)
by Fernando
01:10
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 35
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 9
dl 0
loc 35
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more or different data.

There are several approaches to avoid long parameter lists:

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a
            :py:class:`torch.Tensor` or NumPy array. It is converted to a 4D
            tensor by :func:`~torchio.utils.ensure_4d`, whose interpretation
            of the input dimensions depends on :attr:`num_spatial_dims` and
            :attr:`channels_last`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found in the
            data read from :attr:`path`. If ``False``, images will not be
            checked for the presence of NaNs.
        num_spatial_dims: Optional hint forwarded to
            :func:`~torchio.utils.ensure_4d` to help it interpret the
            dimensions of the input data. If ``None``, ``ensure_4d`` decides
            on its own.
        channels_last: Forwarded to :func:`~torchio.utils.ensure_4d` when
            parsing the input data. If ``True``, :meth:`save` moves the first
            (channels) dimension back to the last position before writing.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> # Also:
        >>> image = torchio.ScalarImage('t1.nii.gz')
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Also:
        >>> label_image = torchio.LabelMap('t1_seg.nii.gz')
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """

    # Keys managed by this class; they may not be passed through **kwargs.
    # Defined after the docstring so that the docstring is really the first
    # statement of the class body and populates ``Image.__doc__``.
    PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM

    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: bool = True,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        # Fail fast on reserved keys, before any (possibly costly) parsing
        for key in self.PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        self._loaded = False
        # Both layout hints must be set before parse_tensor(), which reads them
        self.num_spatial_dims = num_spatial_dims
        self.channels_last = channels_last
        tensor = self.parse_tensor(tensor)
        # parse_affine() never returns None: it substitutes an identity matrix
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Not loaded yet: avoid touching the data, which would load it
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        # The first dimension of the 4D data holds the channels
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        # Estimate only: assumes float32, i.e. 4 bytes per voxel, regardless
        # of the actual dtype of the data
        return np.prod(self.shape) * 4

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`pathlib.Path`.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert a NumPy array or tensor to a 4D tensor (``None`` passes)."""
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning an identity matrix if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` is not :math:`4 \\times 4`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and cache it in this dictionary.

        The file is expected to be monomodal/grayscale and 2D or 3D. The data
        read from :attr:`path` is converted to a 4D tensor with
        :meth:`parse_tensor_shape` and stored under ``DATA``; the affine is
        stored under ``AFFINE``. If :attr:`check_nans` is ``True``, a warning
        is issued when the data contains NaNs.

        Nothing is done if the image is already loaded or has no path. This
        method returns ``None``; use :attr:`data` and :attr:`affine` to
        access the loaded values.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path):
        """Save image to disk.

        If :attr:`channels_last` is ``True``, the channels dimension is moved
        back to the last position; singleton dimensions are then squeezed.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        tensor = self[DATA]
        if self.channels_last:
            tensor = tensor.permute(1, 2, 3, 0)
        tensor = tensor.squeeze()
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        # True if the first spatial dimension of the 4D data has size 1
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans):
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image with the voxels inside a bounding box.

        All channels are kept. The affine of the new image is translated so
        that its origin is the world position of ``index_ini``. Non-protected
        items of this image are shared (not copied) with the new image.

        Args:
            index_ini: Voxel indices ``(i0, j0, k0)`` of the first corner,
                inclusive.
            index_fin: Voxel indices ``(i1, j1, k1)`` of the opposite corner,
                exclusive (as in Python slicing).
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Slice every channel: the cached data is already a 4D channels-first
        # tensor, so indexing channel 0 only would drop the other channels
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        extra_items = {
            key: value  # shared by reference with this image
            for key, value in self.items()
            if key not in self.PROTECTED_KEYS
        }
        cropped = self.__class__(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            # The patch is already 4D and channels-first: describe it as such
            # so that it is not re-interpreted while being parsed
            num_spatial_dims=3,
            channels_last=False,
            **extra_items,
        )
        # Restore this image's layout flags so that e.g. save() behaves the
        # same for the cropped image as for the original one
        cropped.num_spatial_dims = self.num_spatial_dims
        cropped.channels_last = self.channels_last
        return cropped
326
327
328
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` keyword other than
            :attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing ``type`` keyword is treated as the only accepted value
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
341
342
343
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` keyword other than
            :attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing ``type`` keyword is treated as the only accepted value
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
356