Passed
Pull Request — master (#246)
by Fernando
58s
created

torchio.data.image.Image.parse_tensor()   A

Complexity

Conditions 3

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 2
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
12
from ..torchio import (
13
    TypeData,
14
    TypePath,
15
    TypeTripletInt,
16
    TypeTripletFloat,
17
    DATA,
18
    TYPE,
19
    AFFINE,
20
    PATH,
21
    STEM,
22
    INTENSITY,
23
    LABEL,
24
    REPO_URL,
25
)
26
from .io import read_image, write_image
27
28
29
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('image_copy.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a
            :py:class:`torch.Tensor` or NumPy array with 2, 3, 4 or 5
            dimensions. It is reshaped to a 4D tensor with layout
            :math:`(C, D, H, W)` by :py:meth:`parse_tensor_shape`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found when
            the image is loaded from :attr:`path`. If ``False``, images will
            not be checked for the presence of NaNs.
        spatial_dimensions: Number of spatial dimensions, used to interpret
            3D inputs. If ``2``, a 3D input is a multichannel 2D image; if
            ``3``, it is a single-channel volume :math:`(D, H, W)`. If
            ``None``, the layout of 3D inputs is guessed.
        channels_last: If ``True``, 4D and 5D inputs, and 3D inputs with
            ``spatial_dimensions=2``, are assumed to store the channels in
            the last dimension and are permuted to channels-first.
            :py:meth:`save` permutes the data back to channels-last.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> # Also:
        >>> image = torchio.ScalarImage('t1.nii.gz')
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Also:
        >>> label_image = torchio.LabelMap('t1_seg.nii.gz')
        >>> image = torchio.Image(tensor=torch.rand(10, 20, 30))
        >>> image.shape
        (1, 10, 20, 30)
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """

    # Dictionary keys managed by the class; users may not pass them as kwargs
    PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM

    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            spatial_dimensions: Optional[int] = None,
            channels_last: bool = True,
            **kwargs: Any,
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False
        # Both attributes are read by parse_tensor_shape, so set them first
        self.spatial_dimensions = spatial_dimensions
        self.channels_last = channels_last
        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)  # identity matrix if None
        if tensor is not None:
            # In-memory data: nothing will be read from disk
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in self.PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Lazy loading: data and affine are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes used by the data tensor, based on its dtype."""
        data = self.data
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an expanded :class:`Path` that must exist.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D :class:`torch.Tensor`, or return ``None``.

        NumPy arrays are converted with :func:`torch.from_numpy`, so memory
        is shared with the input array.

        Raises:
            TypeError: If ``tensor`` is neither a tensor nor a NumPy array.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        if not isinstance(tensor, torch.Tensor):
            message = (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' not {builtins_type_name(tensor)}'
            ) if False else (
                'Input tensor must be a PyTorch tensor or NumPy array,'
                f' not {tensor.__class__}'
            )
            raise TypeError(message)
        tensor = self.parse_tensor_shape(tensor)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape a 2D to 5D tensor into a 4D tensor :math:`(C, D, H, W)`.

        How the input dimensions are interpreted:

        - 5D: :math:`(X, X, X, 1, X)`, e.g. a 5D NIfTI image with a single
          time point. The singleton fourth dimension is dropped and the
          result is handled as a 4D input.
        - 4D: 3D multichannel. Permuted from :math:`(D, H, W, C)` if
          :attr:`channels_last`, otherwise assumed to be :math:`(C, D, H, W)`.
        - 3D: ambiguous. Resolved with :attr:`spatial_dimensions`, or guessed
          when it is ``None``: a first or last dimension of size 3 is taken
          as the RGB channels of a 2D image, anything else as a 3D volume.
        - 2D: single-channel 2D image :math:`(H, W)`.

        Raises:
            ValueError: If the number of dimensions is not supported.
        """
        # I wish named tensors were properly supported in PyTorch
        num_dimensions = tensor.ndim
        if num_dimensions == 5 and tensor.shape[-2] == 1:  # (X, X, X, 1, X)
            tensor = tensor[..., 0, :]  # drop the singleton dimension
            num_dimensions = 4  # handled by the 4D branch below
        if num_dimensions == 4:  # assume 3D multichannel
            if self.channels_last:  # (D, H, W, C)
                tensor = tensor.permute(3, 0, 1, 2)  # (C, D, H, W)
        elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
            if self.spatial_dimensions == 2:
                if self.channels_last:  # (H, W, C)
                    tensor = tensor.permute(2, 0, 1)  # (C, H, W)
                tensor = tensor.unsqueeze(1)  # (C, 1, H, W)
            elif self.spatial_dimensions == 3:  # (D, H, W)
                tensor = tensor[None]  # (1, D, H, W)
            else:  # try to guess
                shape = tensor.shape
                if shape[-1] == 3:  # (H, W, 3)
                    tensor = tensor.permute(2, 0, 1)  # (3, H, W)
                    tensor = tensor.unsqueeze(1)  # (3, 1, H, W)
                elif shape[0] == 3:  # (3, H, W)
                    tensor = tensor.unsqueeze(1)  # (3, 1, H, W)
                else:  # (D, H, W)
                    tensor = tensor[None]  # (1, D, H, W)
        elif num_dimensions == 2:  # assume 2D monochannel (H, W)
            tensor = tensor[None, None]  # (1, 1, H, W)
        else:
            message = (
                f'{num_dimensions}D images not supported yet. Please create an'
                f' issue in {REPO_URL} if you would like support for them'
            )
            raise ValueError(message)
        assert tensor.ndim == 4
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning a 4x4 identity matrix if ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and cache data and affine in the dict.

        The data is reshaped with :py:meth:`parse_tensor_shape` into a 4D
        tensor :math:`(C, D, H, W)`. If :attr:`check_nans` is ``True``, a
        warning is issued when the data contains NaNs.

        This is a no-op if the image is already loaded or has no path.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path):
        """Save image to disk.

        If :attr:`channels_last` is ``True``, the data is permuted to
        :math:`(D, H, W, C)` first. All singleton dimensions are squeezed
        before writing.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        tensor = self[DATA]
        if self.channels_last:
            tensor = tensor.permute(1, 2, 3, 0)
        tensor = tensor.squeeze()
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        # Data is (C, D, H, W); 2D images are stored with D == 1
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        center_index = (size - 1) / 2
        # SimpleITK physical points are expressed in LPS+
        x_lps, y_lps, z_lps = image.TransformContinuousIndexToPhysicalPoint(
            center_index)
        if lps:
            return (x_lps, y_lps, z_lps)
        else:
            return (-x_lps, -y_lps, z_lps)  # flip L->R and P->A for RAS+

    def set_check_nans(self, check_nans):
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image cropped to the voxel box ``[index_ini, index_fin)``.

        All channels are kept. The affine of the new image is translated so
        that voxel ``index_ini`` of this image becomes the new origin. Extra
        dictionary items are copied by reference.

        Args:
            index_ini: Three start indices along the spatial dimensions.
            index_fin: Three (exclusive) end indices along the spatial
                dimensions.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Keep every channel; the patch is already (C, D, H, W)
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
            # The patch is already channels-first and 4D, so it must not be
            # permuted or re-guessed by parse_tensor_shape
            channels_last=False,
        )
        for key, value in self.items():
            if key in self.PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        cropped = self.__class__(**kwargs)
        # Restore the layout options of the source so save() behaves the same
        cropped.channels_last = self.channels_last
        cropped.spatial_dimensions = self.spatial_dimensions
        return cropped
356
357
358
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing ``type`` is treated as the expected one
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
371
372
373
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing ``type`` is treated as the expected one
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
386