Passed
Push — master ( 1e0c7f...57cb8b )
by Fernando
01:33 queued 36s
created

torchio.data.image.Image.get_bounds()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 8
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import (
12
    nib_to_sitk,
13
    get_rotation_and_spacing_from_affine,
14
    get_stem,
15
    ensure_4d,
16
)
17
from ..torchio import (
18
    TypeData,
19
    TypePath,
20
    TypeTripletInt,
21
    TypeTripletFloat,
22
    DATA,
23
    TYPE,
24
    AFFINE,
25
    PATH,
26
    STEM,
27
    INTENSITY,
28
    LABEL,
29
)
30
from .io import read_image, write_image
31
32
33
# Dictionary keys that Image manages itself. Passing any of them as an extra
# keyword argument to Image() raises a ValueError, and they are skipped when
# an image's custom items are copied into a new instance.
PROTECTED_KEYS = (DATA, AFFINE, TYPE, PATH, STEM)
34
35
36
class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel`, or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read. The data is expected to be 2D or
            3D, and may have multiple channels (see :attr:`num_spatial_dims` and
            :attr:`channels_last`).
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
            If ``None``, a deprecation warning is issued and
            :attr:`torchio.INTENSITY` is used.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
            the dimensions meanings. If 2D, the shape will be interpreted as
            :math:`(H, W)`. If 3D, the number of spatial dimensions should be
            determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
            is not given and the shape is 3 along the first or last dimensions,
            it will be interpreted as a multichannel 2D image. Otherwise, it
            will be interpreted as a 3D image with a single channel.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        num_spatial_dims: If ``2`` and the input tensor has 3 dimensions, it
            will be interpreted as a multichannel 2D image. If ``3`` and the
            input has 3 dimensions, it will be interpreted as a
            single-channel 3D volume.
        channels_last: If ``True``, the last dimension of the input will be
            interpreted as the channels. Defaults to ``True`` if :attr:`path` is
            given and ``False`` otherwise.
        **kwargs: Items that will be added to the image dictionary, e.g.
            acquisition parameters. The reserved keys in ``PROTECTED_KEYS``
            may not be used.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, or
            if a reserved key is passed in :attr:`kwargs`.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: Optional[str] = None,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,
            channels_last: Optional[bool] = None,
            **kwargs: Dict[str, Any],
            ):
        self.check_nans = check_nans
        self.num_spatial_dims = num_spatial_dims

        if type is None:
            warnings.warn(
                'Not specifying the image type is deprecated and will be'
                ' mandatory in the future. You can probably use ScalarImage or'
                ' LabelMap instead',
                stacklevel=2,  # point at the caller, not at this constructor
            )
            type = INTENSITY

        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')

        # Reject reserved keys before doing any (potentially expensive)
        # tensor parsing or NaN scanning
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        self._loaded = False

        # Channels are typically stored in the last dimension on disk, but a
        # tensor given directly is expected to have channels in the first dim
        if channels_last is None:
            channels_last = path is not None
        self.channels_last = channels_last

        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)
        if tensor is not None:
            # Data was given in memory, so there is nothing to load lazily
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type

    def __repr__(self):
        properties = []
        if self._loaded:
            # These properties read the data, so they are only shown once it
            # is in memory; otherwise repr() would trigger a load from disk
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Accessing the data or affine triggers the lazy load from disk
        if item in (DATA, AFFINE):
            if item not in self:
                self._load()
        return super().__getitem__(item)

    def __array__(self):
        """Return the image data as a NumPy array (used by np.asarray)."""
        return self[DATA].numpy()

    def __copy__(self):
        """Return a shallow copy of the image.

        The new image shares the data tensor, the affine and any custom
        items with this one.
        """
        return self._copy_with(self.data, self.affine)

    def _copy_with(self, tensor: torch.Tensor, affine: np.ndarray) -> 'Image':
        """Build a new instance of this class from new data and metadata.

        The image type, path, NaN-checking setting and all custom
        (non-protected) items of ``self`` are carried over. Custom item
        values are shared, not copied.

        Args:
            tensor: 4D channels-first tensor for the new image.
            affine: :math:`4 \\times 4` affine for the new image.
        """
        kwargs = dict(
            tensor=tensor,
            affine=affine,
            type=self.type,
            path=self.path,
            # The tensor is already in channels-first 4D layout
            channels_last=False,
            # Respect the user's choice instead of resetting to the default
            check_nans=self.check_nans,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value
        return self.__class__(**kwargs)

    @property
    def data(self):
        """Image data tensor; loads it from disk on first access."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias of :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix mapping voxel indices to world coordinates."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor, channels first."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data tensor without the channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis direction codes of the affine, as computed by NiBabel."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes taken by the data tensor.

        Uses the actual element size of the tensor dtype rather than assuming
        float32, so e.g. boolean or integer label maps are reported correctly.
        """
        data = self.data
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    def get_bounds(self):
        """Get image bounds in mm.

        The bounds are computed by mapping the voxel-index corners
        ``(-0.5, -0.5, -0.5)`` and ``spatial_shape - 0.5`` through the affine,
        i.e. they include the outer half of the first and last voxels.

        Returns:
            Tuple ``(bounds_x, bounds_y, bounds_z)``. Each element is a list
            ``[first, last]`` with the world coordinate along that axis of the
            first corner and of the opposite corner. ``first`` may be greater
            than ``last`` if the affine flips the axis.
        """
        first_index = 3 * (-0.5,)
        last_index = np.array(self.spatial_shape) - 0.5
        first_point = nib.affines.apply_affine(self.affine, first_index)
        last_point = nib.affines.apply_affine(self.affine, last_index)
        array = np.array((first_point, last_point))
        bounds_x, bounds_y, bounds_z = array.T.tolist()
        return bounds_x, bounds_y, bounds_z

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an expanded :class:`~pathlib.Path`.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If ``path`` is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as exc:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from exc
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(
            self,
            tensor: Optional[TypeData],
            ) -> Optional[torch.Tensor]:
        """Convert the input to a 4D tensor, optionally checking for NaNs.

        NumPy arrays are converted with :func:`torch.from_numpy`. Returns
        ``None`` if ``tensor`` is ``None``.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn('NaNs found in tensor')
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape ``tensor`` to 4D using this image's channel settings."""
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate the affine, returning an identity matrix if ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape ``(4, 4)``.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def _load(self) -> None:
        r"""Load the image from disk and cache it in the dictionary.

        After this call, ``self[DATA]`` holds the tensor read from
        :attr:`path`, reshaped by :meth:`parse_tensor_shape`, and
        ``self[AFFINE]`` holds the :math:`4 \times 4` affine returned by the
        reader. Does nothing if the image is already loaded.
        """
        if self._loaded:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=True, channels_last=True):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, the singleton dimensions will be removed
                before saving.
            channels_last: If ``True``, the channels will be saved in the last
                dimension.
        """
        write_image(
            self[DATA],
            self[AFFINE],
            path,
            squeeze=squeeze,
            channels_last=channels_last,
        )

    def is_2d(self) -> bool:
        """Return ``True`` if the third-to-last dimension has size 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return np.asarray(self)

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        The center is the world position of the continuous voxel index
        ``(size - 1) / 2`` along each spatial axis.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans: bool):
        """Enable or disable NaN checks for subsequent loads/parses."""
        self.check_nans = check_nans

    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
        """Return a new image with the voxels in ``[index_ini, index_fin)``.

        The affine origin is moved to the world position of ``index_ini`` so
        the cropped patch stays at the same physical location. The patch data
        is cloned; other items are shared with this image.

        Args:
            index_ini: First voxel index (inclusive) along each spatial axis.
            index_fin: Last voxel index (exclusive) along each spatial axis.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        return self._copy_with(patch, new_affine)
385
386
387
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # The type is fixed for this class. An explicit type=INTENSITY is
        # tolerated, e.g. when Image copies itself passing type=self.type.
        if kwargs.get('type', INTENSITY) != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
415
416
417
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is given.
    """
    def __init__(self, *args, **kwargs):
        # The type is fixed for this class. An explicit type=LABEL is
        # tolerated, e.g. when Image copies itself passing type=self.type.
        if kwargs.get('type', LABEL) != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
436