Passed
Pull Request — master (#246)
by Fernando
58s
created

torchio.data.image.Image.__init__()   C

Complexity

Conditions 9

Size

Total Lines 35
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 9
dl 0
loc 35
rs 6.6666
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists, for example introducing a parameter object that groups related arguments together.

1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
12
from ..torchio import (
13
    TypeData,
14
    TypePath,
15
    TypeTripletInt,
16
    TypeTripletFloat,
17
    DATA,
18
    TYPE,
19
    AFFINE,
20
    PATH,
21
    STEM,
22
    INTENSITY,
23
    LABEL,
24
    REPO_URL,
25
)
26
from .io import read_image, write_image
27
28
29
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. Data read from disk is reshaped
            following the same rules as :attr:`tensor`.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: Image data given as a :py:class:`torch.Tensor` or NumPy
            array, used when :attr:`path` is not given. It is stored with
            shape :math:`(C, D, H, W)`. Accepted input shapes are:

            - 2D: :math:`(H, W)`, a single-channel 2D image.
            - 3D: :math:`(D, H, W)` if :attr:`spatial_dimensions` is ``3``;
              a multichannel 2D image, :math:`(H, W, C)` if
              :attr:`channels_last` else :math:`(C, H, W)`, if it is ``2``.
              If :attr:`spatial_dimensions` is ``None``, a tensor with a
              dimension of size 3 is assumed to be a 2D RGB image
              (:math:`(3, H, W)` or :math:`(H, W, 3)`), and any other tensor
              a single-channel 3D image.
            - 4D: :math:`(D, H, W, C)` if :attr:`channels_last`, else
              :math:`(C, D, H, W)`.
            - 5D: :math:`(D, H, W, 1, C)`; the singleton dimension is removed
              and the result is handled as a 4D tensor.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        spatial_dimensions: ``2`` or ``3``. Only used to disambiguate 3D
            tensors (see :attr:`tensor`). If ``None``, the layout is guessed.
        channels_last: If ``True``, the channels dimension of 4D tensors
            (and of 3D tensors with ``spatial_dimensions=2``) is expected to
            be the last one. :meth:`save` also uses it to move channels back
            to the last dimension before writing.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> # Also:
        >>> image = torchio.ScalarImage('t1.nii.gz')
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Also:
        >>> label_image = torchio.LabelMap('t1_seg.nii.gz')
        >>> image = torchio.Image(tensor=torch.rand(4, 5, 6))  # 3D, 1 channel
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """

    # Dictionary keys managed by the image itself; users may not pass them
    # as extra **kwargs items.
    PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM

    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            spatial_dimensions: Optional[int] = None,
            channels_last: bool = True,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        for key in self.PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)
        super().__init__(**kwargs)
        self._loaded = False
        self.check_nans = check_nans
        # parse_tensor_shape reads these two attributes, so they must be set
        # before the tensor is parsed
        self.spatial_dimensions = spatial_dimensions
        self.channels_last = channels_last
        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)  # identity matrix if None
        if tensor is not None:
            # In-memory data: nothing to load lazily
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image data tensor with shape :math:`(C, D, H, W)`."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias for :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """:math:`4 \\times 4` affine matrix of the image."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type, e.g. :attr:`torchio.INTENSITY`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor, :math:`(C, D, H, W)`."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data without the channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, as returned by :func:`nib.aff2axcodes`."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing derived from the affine matrix."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes used by the data tensor.

        Computed from the tensor's actual element size, so it is also correct
        for dtypes other than float32 (e.g. float64 NumPy input).
        """
        data = self.data
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        """Return the spacing formatted as ``(x.xx, y.yy, z.zz)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing, user-expanded :class:`Path`.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If it is neither an existing file nor an
                existing directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D :class:`torch.Tensor`.

        NumPy arrays are wrapped with :func:`torch.from_numpy` (memory is
        shared) and the shape is normalized by :meth:`parse_tensor_shape`.
        Returns ``None`` if ``tensor`` is ``None``.

        Raises:
            TypeError: If ``tensor`` is neither a tensor nor a NumPy array.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        elif not isinstance(tensor, torch.Tensor):
            message = (
                'Tensor must be a PyTorch tensor or a NumPy array,'
                f' not {type(tensor)}'
            )
            raise TypeError(message)
        return self.parse_tensor_shape(tensor)

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape ``tensor`` to :math:`(C, D, H, W)`.

        See the ``tensor`` argument of :class:`Image` for the accepted input
        layouts. 2D images get a singleton depth dimension.

        Raises:
            ValueError: If the number of dimensions is not supported.
        """
        # I wish named tensors were properly supported in PyTorch
        num_dimensions = tensor.ndim
        if num_dimensions == 5 and tensor.shape[-2] == 1:
            # (D, H, W, 1, C): drop the singleton axis and carry on as 4D.
            # The dimension count must be refreshed, otherwise the tensor
            # would be rejected below as a 5D image.
            tensor = tensor[..., 0, :]
            num_dimensions = tensor.ndim
        if num_dimensions == 4:  # assume 3D multichannel
            if self.channels_last:  # (D, H, W, C)
                tensor = tensor.permute(3, 0, 1, 2)  # (C, D, H, W)
        elif num_dimensions == 2:  # assume 2D monochannel (H, W)
            tensor = tensor[None, None]  # (1, 1, H, W)
        elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
            if self.spatial_dimensions == 2:
                if self.channels_last:  # (H, W, C)
                    tensor = tensor.permute(2, 0, 1)  # (C, H, W)
                tensor = tensor.unsqueeze(1)  # (C, 1, H, W)
            elif self.spatial_dimensions == 3:  # (D, H, W)
                tensor = tensor[None]  # (1, D, H, W)
            else:  # try to guess
                shape = tensor.shape
                maybe_rgb = 3 in shape
                if maybe_rgb:
                    if shape[-1] == 3:  # (H, W, 3)
                        tensor = tensor.permute(2, 0, 1)  # (3, H, W)
                    tensor = tensor.unsqueeze(1)  # (3, 1, H, W)
                else:  # (D, H, W)
                    tensor = tensor[None]  # (1, D, H, W)
        else:
            message = (
                f'{num_dimensions}D images not supported yet. Please create an'
                f' issue in {REPO_URL} if you would like support for them'
            )
            raise ValueError(message)
        assert tensor.ndim == 4  # internal invariant of the branches above
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning an identity matrix if ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If its shape is not :math:`(4, 4)`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and cache data and affine in the image.

        Does nothing if the image is already loaded or has no path. The data
        is reshaped with :meth:`parse_tensor_shape`, so it is stored with
        shape :math:`(C, D_{in}, H_{in}, W_{in})`, and the affine is a
        :math:`4 \times 4` matrix. If :attr:`check_nans` is ``True``, a
        warning is issued when the data contains NaNs.
        """
        if self._loaded or self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path):
        """Save image to disk.

        If :attr:`channels_last` is ``True``, channels are moved back to the
        last dimension; singleton dimensions are squeezed before writing.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        tensor = self[DATA]
        if self.channels_last:
            tensor = tensor.permute(1, 2, 3, 0)
        tensor = tensor.squeeze()
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        """Return ``True`` if the depth dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        center_index = (size - 1) / 2
        # SimpleITK physical points are LPS+; flip the first two axes for RAS+
        l, p, s = image.TransformContinuousIndexToPhysicalPoint(
            center_index.tolist())
        if lps:
            return (l, p, s)
        else:
            return (-l, -p, s)

    def set_check_nans(self, check_nans):
        """Enable or disable the NaN check performed by :meth:`load`."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image containing a region of this one.

        Args:
            index_ini: First voxel index ``(i0, j0, k0)`` of the region
                (inclusive) along the three spatial dimensions.
            index_fin: End voxel index ``(i1, j1, k1)`` of the region
                (exclusive).

        Returns:
            An instance of the same class with all channels of the region,
            an affine whose origin is moved to ``index_ini``, and the same
            type, path, options and extra items as this image.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Keep every channel and a 4D layout. A 3D single-channel patch would
        # go through the shape-guessing logic, which reads any cropped
        # dimension of size 3 as RGB channels.
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()  # (C, D, H, W)
        if self.channels_last:
            # The constructor moves the last dimension to the front when
            # channels_last is True, so hand it (D, H, W, C)
            patch = patch.permute(1, 2, 3, 0)
        # NOTE: extra item values are shared with this image, not copied
        kwargs = {
            key: value
            for key, value in self.items()
            if key not in self.PROTECTED_KEYS
        }
        kwargs.update(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            spatial_dimensions=self.spatial_dimensions,
            channels_last=self.channels_last,
        )
        return self.__class__(**kwargs)
356
357
358
class ScalarImage(Image):
    """Image whose type is always :py:attr:`torchio.INTENSITY`.

    Shortcut for ``Image(..., type=torchio.INTENSITY)``. See
    :py:class:`~torchio.Image` for the full list of arguments.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing ``type`` is treated as the only allowed value
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
371
372
373
class LabelMap(Image):
    """Image whose type is always :py:attr:`torchio.LABEL`.

    Shortcut for ``Image(..., type=torchio.LABEL)``. See
    :py:class:`~torchio.Image` for the full list of arguments.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing ``type`` is treated as the only allowed value
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
386