Passed
Pull Request — master (#246)
by Fernando
59s
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 35
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 9
dl 0
loc 35
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

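There are several approaches to avoid long parameter lists, for example: replacing a parameter with a method call, passing a whole object instead of several values extracted from it, or introducing a parameter object that groups related parameters.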

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Keys that Image manages itself. Passing any of them as an extra keyword
# argument to Image.__init__ raises ValueError, and Image.crop does not copy
# them from the source image into the cropped one.
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a
            :py:class:`torch.Tensor` or NumPy array, e.g. a 3D one with
            dimensions :math:`(D, H, W)`. NumPy arrays are converted to
            tensors, and the result is passed through
            :func:`torchio.utils.ensure_4d` to obtain a 4D tensor with
            dimensions :math:`(C, D, H, W)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array or :py:class:`torch.Tensor`
            (tensors are converted to NumPy). If ``None``, :attr:`affine` is
            an identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image loaded from :attr:`path`. If ``False``, images will
            not be checked for the presence of NaNs.
        num_spatial_dims: Number of spatial dimensions of the input data.
            Forwarded to :func:`torchio.utils.ensure_4d` when the tensor is
            parsed (presumably to disambiguate the shape of 3D inputs).
        channels_last: Forwarded to :func:`torchio.utils.ensure_4d` when the
            tensor is parsed; presumably ``True`` means that the channels
            dimension is the last one in the input data.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, or
            if a reserved key is passed in :attr:`kwargs`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: bool = False,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False
        # Both attributes must be set before parsing: parse_tensor_shape reads them
        self.num_spatial_dims = num_spatial_dims
        self.channels_last = channels_last
        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)  # never None: defaults to identity
        if tensor is not None:
            # In-memory data: the image counts as loaded and path is not read
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        self._check_protected_keys(kwargs)
        # Calling dict.__init__ on an already-populated dict only adds items,
        # so DATA and AFFINE set above are kept
        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    @staticmethod
    def _check_protected_keys(kwargs: Dict[str, Any]) -> None:
        """Raise ``ValueError`` if ``kwargs`` contains a reserved key."""
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Do not trigger a (possibly expensive) load just to print
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image tensor, loaded from disk on first access if needed."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias of :attr:`data`."""
        return self.data

    @property
    def affine(self):
        r""":math:`4 \times 4` affine matrix of the image."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor, :math:`(C, D, H, W)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data tensor without the channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, as returned by :func:`nibabel.aff2axcodes`."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing derived from the affine matrix."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes occupied by the data tensor, based on its dtype."""
        data = self.data
        # element_size() reflects the real dtype (e.g. 1 byte for bool label
        # maps, 8 for float64) instead of assuming float32
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`pathlib.Path` or ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D :class:`torch.Tensor`, or return ``None``."""
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        return self.parse_tensor_shape(tensor)

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize the tensor shape using :func:`torchio.utils.ensure_4d`."""
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[TypeData]) -> np.ndarray:
        r"""Validate ``affine`` and return it as a :math:`4 \times 4` array.

        ``None`` yields the identity matrix and PyTorch tensors are converted
        to NumPy arrays.

        Raises:
            TypeError: If ``affine`` is not a NumPy array or PyTorch tensor.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if isinstance(affine, torch.Tensor):
            # The constructor annotation accepts TypeData, i.e. tensors too
            affine = affine.detach().cpu().numpy()
        if not isinstance(affine, np.ndarray):
            message = (
                'Affine must be a NumPy array or a PyTorch tensor,'
                f' not {type(affine)}'
            )
            raise TypeError(message)
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image data and affine from disk, if not done already.

        Nothing happens if the image is already loaded or if no
        :attr:`path` was given. Otherwise, the tensor read from
        :attr:`path` is passed through :func:`torchio.utils.ensure_4d` to
        obtain dimensions :math:`(C, D, H, W)`, and it is stored with its
        :math:`4 \times 4` affine matrix under the :attr:`torchio.DATA` and
        :attr:`torchio.AFFINE` keys. Access the results through
        :attr:`data` and :attr:`affine`.

        If :attr:`check_nans` is ``True``, a warning is issued when the
        loaded tensor contains NaNs.
        """
        if self._loaded or self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=False, channels_last=False):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, all singleton dimensions are removed from
                the tensor before writing.
            channels_last: If ``True``, the tensor is permuted from
                :math:`(C, D, H, W)` to :math:`(D, H, W, C)` before writing.
        """
        # Currently not working well for 2D images
        tensor = self[DATA]
        if channels_last:
            tensor = tensor.permute(1, 2, 3, 0)  # (D, H, W, C)
        if squeeze:
            tensor = tensor.squeeze()
        write_image(tensor, self[AFFINE], path)

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.

        Returns:
            World coordinates of the center of the voxel grid.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        return (r, a, s)

    def set_check_nans(self, check_nans):
        """Enable or disable the NaN warning issued by :meth:`load`."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image containing the voxels inside a bounding box.

        Args:
            index_ini: Voxel indices ``(i0, j0, k0)`` where the crop starts.
            index_fin: Voxel indices ``(i1, j1, k1)`` where the crop ends
                (exclusive, as in Python slicing).

        All channels are kept. The new affine equals this image's affine,
        with the translation set to the world position of ``index_ini``.
        Non-protected items are shallow-copied to the new image, and
        :attr:`type`, :attr:`path` and :attr:`check_nans` are propagated.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Slice every channel: indexing channel 0 only would silently drop the
        # rest of a multichannel image. The patch stays (C, d, h, w);
        # presumably ensure_4d leaves such a tensor unchanged (confirm)
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        # Shallow copy: values are shared with this image
        kwargs = {
            key: value
            for key, value in self.items()
            if key not in PROTECTED_KEYS
        }
        # Constructor arguments take precedence over same-named stored items
        kwargs.update(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
        )
        return self.__class__(**kwargs)
325
326
327
class ScalarImage(Image):
    """Image whose type is always :attr:`torchio.INTENSITY`.

    Convenience alias for :py:class:`~torchio.Image` created with
    ``type=torchio.INTENSITY``.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY  # force the type, whether given or not
        super().__init__(*args, **kwargs)
355
356
357
class LabelMap(Image):
    """Image whose type is always :attr:`torchio.LABEL`.

    Convenience alias for :py:class:`~torchio.Image` created with
    ``type=torchio.LABEL``.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL  # force the type, whether given or not
        super().__init__(*args, **kwargs)
376