Passed
Pull Request — master (#246)
by Fernando
01:08
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 33
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 31
nop 8
dl 0
loc 33
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
12
from ..torchio import (
13
    TypeData,
14
    TypePath,
15
    TypeTripletInt,
16
    TypeTripletFloat,
17
    DATA,
18
    TYPE,
19
    AFFINE,
20
    PATH,
21
    STEM,
22
    INTENSITY,
23
    LABEL,
24
    REPO_URL,
25
)
26
from .io import read_image, write_image
27
28
29
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a
            :py:class:`torch.Tensor` or NumPy array with 2, 3 or 4
            dimensions. 4D data is taken as :math:`(C, D, H, W)` and 2D data
            as :math:`(H, W)`. The layout of 3D data is resolved with
            :attr:`spatial_dimensions`. The stored tensor is always 4D.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image when it is loaded from disk.
        spatial_dimensions: Disambiguates 3D tensors. ``2`` means the tensor
            is :math:`(H, W, C)` and ``3`` means it is :math:`(D, H, W)`.
            If ``None``, the layout is guessed: a last or first dimension of
            size 3 is taken as the channels dimension of a 2D image, and any
            other shape is taken as :math:`(D, H, W)`.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> # Also:
        >>> image = torchio.ScalarImage('t1.nii.gz')
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Also:
        >>> label_image = torchio.LabelMap('t1_seg.nii.gz')
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """

    # Keys managed by the class itself; they cannot be passed through kwargs.
    # Kept below the docstring so the docstring is the first statement of the
    # class body and therefore becomes ``Image.__doc__``.
    PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM

    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            spatial_dimensions: Optional[int] = None,
            **kwargs: Any,
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False
        # Must be set before parse_tensor(), which reads it to resolve the
        # layout of 3D tensors.
        self.spatial_dimensions = spatial_dimensions
        tensor = self.parse_tensor(tensor)
        # parse_affine() never returns None: it falls back to the identity.
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in self.PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        """Describe the image; geometry is only shown once data is loaded."""
        if self._loaded:
            properties = [
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ]
        else:
            # Do not touch the data here, it would trigger a read from disk
            properties = [f'path: "{self.path}"']
        properties.append(f'type: {self.type}')
        joined = '; '.join(properties)
        return f'{self.__class__.__name__}({joined})'

    def __getitem__(self, item):
        """Get a dictionary item, loading the image first if needed.

        The data and the affine are read lazily: the first access to either
        key calls :meth:`load` when the key is not stored yet.
        """
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image data stored under the ``DATA`` key (loaded if needed)."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias of :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine stored under the ``AFFINE`` key (loaded if needed)."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type stored under the ``TYPE`` key."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor as a tuple."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data tensor without its first (channels) dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes computed by :func:`nibabel.aff2axcodes` on the affine."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Spacing extracted from the affine, as a tuple."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        """Number of bytes occupied by the elements of the data tensor."""
        # Use the real element size instead of assuming 4-byte float32
        # voxels, so that e.g. float64 or integer data is reported correctly.
        data = self.data
        return data.element_size() * data.numel()

    def get_spacing_string(self):
        """Format the spacing with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        return f'({", ".join(strings)})'

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing, user-expanded :class:`Path`.

        Args:
            path: Path-like value or ``None``.

        Returns:
            ``None`` if ``path`` is ``None``, the expanded path otherwise.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            # Chain the original error so the root cause stays visible
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert the input data to a 4D tensor.

        NumPy arrays are converted with :func:`torch.from_numpy` and the
        result is reshaped by :meth:`parse_tensor_shape`. ``None`` is
        returned unchanged.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        return self.parse_tensor_shape(tensor)

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape a 2D, 3D or 4D tensor into a 4D one.

        Args:
            tensor: Tensor with 2, 3 or 4 dimensions.

        Returns:
            4D view of the input. 4D input is returned as is, 2D input gets
            two leading singleton dimensions, and 3D input is interpreted
            according to :attr:`spatial_dimensions` (see the class
            documentation).

        Raises:
            ValueError: If the tensor does not have 2, 3 or 4 dimensions.
        """
        num_dimensions = tensor.ndim
        if num_dimensions == 4:  # assume 3D multichannel (C, D, H, W)
            return tensor
        if num_dimensions == 2:  # assume 2D monochannel (1, 1, H, W)
            return tensor[None, None]
        if num_dimensions != 3:
            message = (
                f'{num_dimensions}D images not supported yet. Please create an'
                f' issue in {REPO_URL} if you would like support for them'
            )
            raise ValueError(message)

        # 3D input: 2D multichannel or 3D monochannel?
        if self.spatial_dimensions == 2:
            # Assume (H, W, C)
            return tensor.permute(2, 0, 1)[None]  # (C, 1, H, W)
        if self.spatial_dimensions == 3:
            return tensor[None]  # (1, D, H, W)

        # No hint was given, try to guess
        shape = tensor.shape
        if shape[-1] == 3:  # (H, W, 3)
            return tensor.permute(2, 0, 1)[None]  # (3, 1, H, W)
        if shape[0] == 3:  # (3, H, W)
            return tensor.unsqueeze(1)  # (3, 1, H, W)
        # Everything else is a volume (D, H, W). This includes shapes whose
        # only dimension of size 3 is the middle one, which previously were
        # left 3D and made the method fail.
        return tensor[None]  # (1, D, H, W)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate the affine matrix.

        Args:
            affine: NumPy array of shape (4, 4) or ``None``.

        Returns:
            The input array, or a 4x4 identity matrix if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If the shape of ``affine`` is not (4, 4).
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        """Load the image data and affine from disk, if not loaded yet.

        Nothing is done if the image is already loaded or if there is no
        path to read from. Otherwise the file is read with ``read_image``,
        the tensor is reshaped to 4D with :meth:`parse_tensor_shape`, and the
        tensor and the affine are stored under the ``DATA`` and ``AFFINE``
        keys. Nothing is returned.

        A warning is issued if :attr:`check_nans` is ``True`` and NaNs are
        found in the data.
        """
        if self._loaded or self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        # NOTE(review): squeeze() removes every singleton dimension, not only
        # the channels one, so e.g. a volume with a single slice is written
        # as 2D. This matches the "assume 2D if (1, 1, H, W)" intent.
        tensor = self[DATA].squeeze()
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        """Check whether the third-to-last dimension of the data has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`.

        Only the first channel of the data is passed to the conversion.
        """
        return nib_to_sitk(self[DATA][0], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        center_index = (size - 1) / 2
        point = image.TransformContinuousIndexToPhysicalPoint(center_index)
        left, posterior, superior = point
        if lps:
            return (left, posterior, superior)
        # Flip the first two axes to go from LPS+ to RAS+
        return (-left, -posterior, superior)

    def set_check_nans(self, check_nans):
        """Set whether a warning is issued if NaNs are found when loading."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Create a new image containing a spatial patch of this one.

        All channels are kept. The origin of the affine of the new image is
        the position of ``index_ini`` according to the current affine, and
        the items that are not in ``PROTECTED_KEYS`` are passed on to the new
        image.

        Args:
            index_ini: Three indices of the first voxel of the patch.
            index_fin: Three indices marking the end of the patch, excluded
                as in standard slicing.

        Returns:
            New instance of the same class as this image.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Slice every channel and keep the tensor 4D. Taking only channel 0
        # dropped the other channels and sent a 3D patch through the layout
        # guess of parse_tensor_shape(), which misreads patches whose first
        # or last dimension has size 3.
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in self.PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)
345
346
347
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine; only an explicit, different one is an error
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
360
361
362
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        # A missing type is fine; only an explicit, different one is an error
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
375