Passed
Pull Request — master (#246)
by Fernando
01:08
created

torchio.data.image.Image.spatial_shape()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
12
from ..torchio import (
13
    TypeData,
14
    TypePath,
15
    TypeTripletInt,
16
    TypeTripletFloat,
17
    DATA,
18
    TYPE,
19
    AFFINE,
20
    PATH,
21
    STEM,
22
    INTENSITY,
23
    LABEL,
24
    REPO_URL,
25
)
26
from .io import read_image, write_image
27
28
29
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a
            :py:class:`torch.Tensor` or NumPy array. 2D inputs are
            interpreted as :math:`(H, W)`, 4D inputs as :math:`(C, D, H, W)`
            and 3D inputs as either :math:`(D, H, W)` or a 2D image with
            channels (see :attr:`spatial_dimensions`). The stored tensor is
            always 4D, with a depth of 1 for 2D images.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image.
        spatial_dimensions: Disambiguates 3D tensors. If ``2``, a 3D tensor
            is interpreted as :math:`(H, W, C)`; if ``3``, as
            :math:`(D, H, W)`. Otherwise, a 3D tensor whose last (or else
            first) dimension has size 3 is treated as a 2D RGB image, and any
            other 3D tensor as :math:`(D, H, W)`.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> # Also:
        >>> image = torchio.ScalarImage('t1.nii.gz')
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Also:
        >>> label_image = torchio.LabelMap('t1_seg.nii.gz')
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        torch.Tensor

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """

    # Keys managed by the image itself; they may not be passed via **kwargs
    PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM

    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            spatial_dimensions: Optional[int] = None,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False
        # Must be set before parse_tensor, which reads it through
        # parse_tensor_shape to disambiguate 3D inputs
        self.spatial_dimensions = spatial_dimensions
        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)  # identity matrix if None
        if tensor is not None:
            # In-memory data: nothing to load lazily
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in self.PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        # dict.__init__ adds the kwargs without discarding DATA/AFFINE above
        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a disk read just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        details = '; '.join(properties)
        return f'{self.__class__.__name__}({details})'

    def __getitem__(self, item):
        # Lazy loading: DATA and AFFINE are read from disk on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image tensor (loaded from disk on first access)."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias for :attr:`data`."""
        return self.data

    @property
    def affine(self):
        """:math:`4 \\times 4` affine matrix (loaded on first access)."""
        return self[AFFINE]

    @property
    def type(self):
        """Type of the image, e.g. :attr:`torchio.INTENSITY`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the 4D image tensor."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the tensor without the first (channels) dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, as returned by ``nib.aff2axcodes``."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine matrix."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self) -> int:
        """Number of bytes occupied by the image tensor."""
        data = self.data
        # Use the real element size instead of assuming float32
        return data.numel() * data.element_size()

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals per dimension."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an existing :class:`pathlib.Path`.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message)
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert ``tensor`` to a 4D :class:`torch.Tensor`.

        NumPy arrays are converted with :func:`torch.from_numpy`, which shares
        memory with the array. ``None`` is returned unchanged.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape a 2D, 3D or 4D tensor to :math:`(C, D, H, W)`.

        2D images are stored with a depth of 1, i.e. :math:`(C, 1, H, W)`.

        Raises:
            ValueError: If the tensor does not have 2, 3 or 4 dimensions.
        """
        num_dimensions = tensor.ndim
        if num_dimensions == 4:  # assume 3D multichannel (C, D, H, W)
            pass
        elif num_dimensions == 2:  # assume 2D monochannel (1, 1, H, W)
            tensor = tensor[None, None]
        elif num_dimensions == 3:  # 2D multichannel or 3D monochannel?
            if self.spatial_dimensions == 2:
                # Assume (H, W, C)
                tensor = tensor.permute(2, 0, 1).unsqueeze(1)  # (C, 1, H, W)
            elif self.spatial_dimensions == 3:
                tensor = tensor[None]  # (1, D, H, W)
            else:  # try to guess
                shape = tensor.shape
                if shape[-1] == 3:  # (H, W, 3)
                    tensor = tensor.permute(2, 0, 1).unsqueeze(1)  # (3, 1, H, W)
                elif shape[0] == 3:  # (3, H, W)
                    tensor = tensor.unsqueeze(1)  # (3, 1, H, W)
                else:  # (D, H, W), including a 3 in the middle dimension
                    tensor = tensor[None]  # (1, D, H, W)
        else:
            message = (
                f'{num_dimensions}D images not supported yet. Please create an'
                f' issue in {REPO_URL} if you would like support for them'
            )
            raise ValueError(message)
        # Every branch above produces a 4D tensor
        assert tensor.ndim == 4
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate ``affine``, returning an identity matrix if it is ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` is not :math:`4 \\times 4`.
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and cache it in the image dictionary.

        Nothing is done if the data is already loaded or if no path was
        given. The tensor read from :attr:`path` is reshaped to 4D with
        :meth:`parse_tensor_shape` and stored under :attr:`torchio.DATA`.
        The affine is stored under :attr:`torchio.AFFINE`. If
        :attr:`check_nans` is ``True``, a warning is issued when the tensor
        contains NaNs.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        # squeeze() removes every singleton dimension, so a (1, 1, H, W)
        # image is written as 2D
        tensor = self[DATA].squeeze()
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        """Return ``True`` if the depth dimension (``shape[-3]``) is 1."""
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`.

        Only the first channel is converted.
        """
        return nib_to_sitk(self[DATA][0], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        center_index = (size - 1) / 2
        left, posterior, superior = (
            image.TransformContinuousIndexToPhysicalPoint(center_index))
        if lps:
            return (left, posterior, superior)
        # LPS+ -> RAS+: flip the first two axes
        return (-left, -posterior, superior)

    def set_check_nans(self, check_nans):
        """Set whether NaNs trigger a warning when the image is loaded."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image containing the voxels inside a bounding box.

        All channels are kept. The origin of the affine is moved to
        ``index_ini``, so the patch keeps its position in physical space.

        Args:
            index_ini: Start voxel indices ``(i0, j0, k0)`` (inclusive) along
                the three spatial dimensions.
            index_fin: End voxel indices ``(i1, j1, k1)`` (exclusive).
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # Keep the channels dimension: the patch stays 4D, so it is not
        # re-guessed (e.g. as RGB) by parse_tensor_shape
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
            check_nans=self.check_nans,
        )
        for key, value in self.items():
            if key in self.PROTECTED_KEYS:
                continue
            kwargs[key] = value  # NOTE(review): shared reference, not copied
        return self.__class__(**kwargs)
345
346
347
class ScalarImage(Image):
    """Image whose type is always :py:attr:`torchio.INTENSITY`.

    Convenience subclass of :py:class:`~torchio.Image`; see that class for
    the accepted arguments.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' is treated as the (only) valid one
        requested_type = kwargs.get('type', INTENSITY)
        if requested_type != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs['type'] = INTENSITY
        super().__init__(*args, **kwargs)
360
361
362
class LabelMap(Image):
    """Image whose type is always :py:attr:`torchio.LABEL`.

    Convenience subclass of :py:class:`~torchio.Image`; see that class for
    the accepted arguments.

    Raises:
        ValueError: If a ``type`` keyword argument other than
            :py:attr:`torchio.LABEL` is given.
    """
    def __init__(self, *args, **kwargs):
        # A missing 'type' is treated as the (only) valid one
        requested_type = kwargs.get('type', LABEL)
        if requested_type != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs['type'] = LABEL
        super().__init__(*args, **kwargs)
375