Passed
Push — master ( c291a8...879ee9 )
by Fernando
59s
created

torchio.data.image.Image.numpy()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
12
from ..torchio import (
13
    TypeData,
14
    TypePath,
15
    TypeTripletInt,
16
    TypeTripletFloat,
17
    DATA,
18
    TYPE,
19
    AFFINE,
20
    PATH,
21
    STEM,
22
    INTENSITY,
23
    LABEL,
24
)
25
from .io import read_image, write_image
26
27
28
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 3D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(D, H, W)`.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> # Also:
        >>> image = torchio.ScalarImage('t1.nii.gz')
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Also:
        >>> label_image = torchio.LabelMap('t1_seg.nii.gz')
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """

    # Keys managed by the image itself; they may not be passed as **kwargs.
    PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM

    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        # NOTE: if both path and tensor are given, the tensor takes precedence
        # and the file at path is never read (the image is marked as loaded).
        self._loaded = False
        tensor = self.parse_tensor(tensor)
        # parse_affine already falls back to the identity when affine is None
        affine = self.parse_affine(affine)
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in self.PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a disk read just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Data and affine are loaded lazily on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image data tensor (loaded from disk on first access)."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias of :py:attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix (loaded from disk on first access)."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor, channels first."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data tensor without the channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, as computed by :func:`nibabel.aff2axcodes`."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes used by the data tensor.

        Computed from the actual element size of the tensor rather than
        assuming 4 bytes (float32) per voxel.
        """
        data = self.data
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing, user-expanded :class:`Path`.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a :class:`Path`.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    @staticmethod
    def parse_tensor(tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert a 3D tensor or array to a 4D float tensor ``(1, D, H, W)``.

        Returns ``None`` if ``tensor`` is ``None``.

        Raises:
            RuntimeError: If the input does not have exactly 3 dimensions.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        num_dimensions = tensor.dim()
        if num_dimensions != 3:
            message = (
                'The input tensor must have 3 dimensions (D, H, W),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        tensor = tensor.unsqueeze(0)  # add channels dimension
        tensor = tensor.float()
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning the 4x4 identity if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and cache it in the dictionary.

        The file is expected to be monomodal/grayscale and 2D or 3D. 2D data
        is given a leading singleton dimension so it becomes 3D, then a
        channels dimension is added. For 2D/3D files, ``self[DATA]`` ends up
        with size :math:`(1, D_{in}, H_{in}, W_{in})` and ``self[AFFINE]``
        holds the affine returned by ``read_image``.

        Nothing is done if the image is already loaded or has no path.
        A warning is issued if :py:attr:`check_nans` is ``True`` and the data
        contains NaNs.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        # https://github.com/pytorch/pytorch/issues/9410#issuecomment-404968513
        tensor = tensor[(None,) * (3 - tensor.ndim)]  # force to be 3D
        # Named tensors would be preferable here (align_to('channels', ...)),
        # but they cannot be collated yet:
        # https://github.com/pytorch/pytorch/issues/29010
        # See also https://discuss.pytorch.org/t/collating-named-tensors/78650/4
        tensor = tensor.unsqueeze(0)  # add channels dimension
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path: TypePath) -> None:
        """Save image to disk.

        A singleton channels dimension is removed before writing. If the
        image is 2D, i.e. its first spatial dimension has size 1 (see
        :meth:`is_2d`), that dimension is removed as well, so a
        ``(1, 1, H, W)`` tensor is written as ``(H, W)``. Other singleton
        spatial dimensions are kept so the data stays consistent with the
        affine.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        tensor = self[DATA].squeeze(0)  # drop channels dim if single-channel
        if tensor.dim() == 3 and tensor.shape[0] == 1:  # 2D image (1, H, W)
            tensor = tensor.squeeze(0)
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data.

        The tensor is detached from the autograd graph and moved to the CPU
        first, so this also works for tensors that require grad or live on a
        GPU. For CPU tensors, the array shares memory with the tensor.
        """
        return self[DATA].detach().cpu().numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA][0], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        # Continuous index of the central point; passed as a plain list of
        # floats, which SimpleITK accepts as a point/index sequence
        center_index = ((size - 1) / 2).tolist()
        # SimpleITK (ITK) physical points are expressed in LPS+ coordinates
        x, y, z = image.TransformContinuousIndexToPhysicalPoint(center_index)
        if lps:
            return (x, y, z)
        # LPS+ -> RAS+: flip the first two axes
        return (-x, -y, z)

    def set_check_nans(self, check_nans):
        """Set whether a warning is issued if NaNs are found when loading."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Extract a patch as a new image of the same class.

        The origin of the new affine is moved to the physical position of
        ``index_ini``, so the patch keeps its location in space.

        Args:
            index_ini: Voxel indices ``(i, j, k)`` where the patch starts
                (inclusive).
            index_fin: Voxel indices ``(i, j, k)`` where the patch ends
                (exclusive, as in Python slicing).

        Returns:
            New image with a cloned data tensor. Non-protected dictionary
            items are copied by reference (shallow copy).
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[0, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
        )
        for key, value in self.items():
            if key in self.PROTECTED_KEYS:
                continue
            kwargs[key] = value  # shallow: values are shared with this image
        return self.__class__(**kwargs)
331
332
333
class ScalarImage(Image):
    """Image whose type is always :py:attr:`torchio.INTENSITY`.

    Accepts the same arguments as :py:class:`~torchio.Image`, see its
    documentation for details. The ``type`` argument may be omitted.

    Raises:
        ValueError: If a ``type`` other than :py:attr:`torchio.INTENSITY` is
            passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
346
347
348
class LabelMap(Image):
    """Image whose type is always :py:attr:`torchio.LABEL`.

    Accepts the same arguments as :py:class:`~torchio.Image`, see its
    documentation for details. The ``type`` argument may be omitted.

    Raises:
        ValueError: If a ``type`` other than :py:attr:`torchio.LABEL` is
            passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
361