Passed
Push — master ( 42a45f...85e78d )
by Fernando
01:44
created

torchio.data.image.ScalarImage.__init__()   A

Complexity

Conditions 2

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
import warnings
2
from pathlib import Path
3
from typing import Any, Dict, Tuple, Optional
4
5
import torch
6
import humanize
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from ..utils import nib_to_sitk, get_rotation_and_spacing_from_affine, get_stem
12
from ..torchio import (
13
    TypeData,
14
    TypePath,
15
    TypeTripletInt,
16
    TypeTripletFloat,
17
    DATA,
18
    TYPE,
19
    AFFINE,
20
    PATH,
21
    STEM,
22
    INTENSITY,
23
    LABEL,
24
)
25
from .io import read_image, write_image
26
27
28
class Image(dict):
    r"""TorchIO image.

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('doubled_image.nii.gz')

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_ or `FSL docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files.
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :attr:`path` is not given, :attr:`tensor` must be a 3D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(D, H, W)`. A channels dimension is prepended and the data
            is cast to ``float``.
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image when it is loaded from disk. Tensors passed directly
            through :attr:`tensor` are not checked.
        **kwargs: Items that will be added to image dictionary within the
            subject sample.

    Raises:
        ValueError: If neither :attr:`path` nor :attr:`tensor` is given, if
            :attr:`path` is given together with :attr:`tensor` or
            :attr:`affine`, if :attr:`affine` does not have shape
            :math:`(4, 4)`, or if a reserved key is used in ``kwargs``.
        TypeError: If :attr:`path` cannot be converted to a path or
            :attr:`affine` is not a NumPy array.
        FileNotFoundError: If :attr:`path` is neither a file nor a directory.
        RuntimeError: If :attr:`tensor` does not have 3 dimensions.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> # Also:
        >>> image = torchio.ScalarImage('t1.nii.gz')
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> # Also:
        >>> label_image = torchio.LabelMap('t1_seg.nii.gz')
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm

    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        if path is not None and (tensor is not None or affine is not None):
            message = 'If a path is given, tensor and affine must be None'
            raise ValueError(message)
        self._loaded = False
        tensor = self.parse_tensor(tensor)
        # parse_affine never returns None: a missing affine becomes identity
        affine = self.parse_affine(affine)
        if tensor is not None:
            # In-memory data: store it right away; load() becomes a no-op.
            # NOTE: the NaN check in load() is therefore never run for it.
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in (DATA, AFFINE, TYPE, PATH, STEM):
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            # Avoid triggering a load just to print the image
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        # Data and affine are loaded lazily, on first access
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        """Image data tensor; loaded from disk on first access."""
        return self[DATA]

    @property
    def tensor(self):
        """Alias for :py:attr:`data`."""
        return self.data

    @property
    def affine(self):
        """Affine matrix of the image; loaded from disk on first access."""
        return self[AFFINE]

    @property
    def type(self):
        """Image type, e.g. :attr:`torchio.INTENSITY` or :attr:`torchio.LABEL`."""
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape of the data tensor, with the channels dimension first."""
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        """Shape of the data tensor without the channels dimension."""
        return self.shape[1:]

    @property
    def orientation(self):
        """Axis codes of the affine, as returned by :func:`nibabel.aff2axcodes`."""
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        """Voxel spacing extracted from the affine matrix."""
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        """Approximate size of the data in bytes, assuming 4 bytes per voxel."""
        return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel

    def get_spacing_string(self):
        """Return the spacing formatted with two decimals, e.g. ``(1.00, 1.00, 1.00)``."""
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        """Convert ``path`` to an expanded :class:`pathlib.Path`.

        Returns ``None`` if ``path`` is ``None``.

        Raises:
            TypeError: If ``path`` cannot be converted to a path.
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError as error:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message) from error
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    @staticmethod
    def parse_tensor(tensor: Optional[TypeData]) -> Optional[torch.Tensor]:
        """Convert a 3D array or tensor into a float tensor of shape (1, D, H, W).

        Returns ``None`` if ``tensor`` is ``None``.

        Raises:
            RuntimeError: If the input does not have exactly 3 dimensions.
        """
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        num_dimensions = tensor.dim()
        if num_dimensions != 3:
            message = (
                'The input tensor must have 3 dimensions (D, H, W),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        tensor = tensor.unsqueeze(0)  # add channels dimension
        tensor = tensor.float()
        return tensor

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        """Validate an affine matrix, defaulting to identity when ``None``.

        Raises:
            TypeError: If ``affine`` is not a NumPy array.
            ValueError: If ``affine`` does not have shape (4, 4).
        """
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk and store it in this dictionary.

        The file is expected to be monomodal/grayscale and 2D or 3D.
        Inputs with fewer than 3 dimensions get leading singleton dimensions
        so that they become 3D, and then a channels dimension is prepended.
        The stored tensor therefore has size
        :math:`(1, D_{in}, H_{in}, W_{in})`. The affine returned by the reader
        is stored under the affine key.

        Nothing is returned; the results are available through
        :py:attr:`data` and :py:attr:`affine`. This is a no-op if the image
        is already loaded or has no path.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        # https://github.com/pytorch/pytorch/issues/9410#issuecomment-404968513
        tensor = tensor[(None,) * (3 - tensor.ndim)]  # force to be 3D
        # Remove next line and uncomment the two following ones once/if this issue
        # gets fixed:
        # https://github.com/pytorch/pytorch/issues/29010
        # See also https://discuss.pytorch.org/t/collating-named-tensors/78650/4
        tensor = tensor.unsqueeze(0)  # add channels dimension
        # name_dimensions(tensor, affine)
        # tensor = tensor.align_to('channels', ...)
        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path):
        """Save image to disk.

        All singleton dimensions (including the channels one) are squeezed
        before writing, so a ``(1, 1, H, W)`` tensor is written as 2D.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
        """
        tensor = self[DATA].squeeze()  # assume 2D if (1, 1, H, W)
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        """Return ``True`` if the first spatial dimension has size 1.

        :py:meth:`load` prepends singleton dimensions to 2D inputs, so 2D
        files end up in this form.
        """
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data.

        The array is obtained with :py:meth:`torch.Tensor.numpy`, so for CPU
        tensors it shares memory with :py:attr:`data`.
        """
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the first channel of the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA][0], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        image = self.as_sitk()
        size = np.array(image.GetSize())
        # Continuous index of the central point (may fall between voxels)
        center_index = (size - 1) / 2
        l, p, s = image.TransformContinuousIndexToPhysicalPoint(center_index)
        if lps:
            return (l, p, s)
        else:
            # Flip the first two axes to go from LPS+ to RAS+
            return (-l, -p, s)

    def set_check_nans(self, check_nans):
        """Set whether :py:meth:`load` warns when NaNs are found."""
        self.check_nans = check_nans

    def crop(self, index_ini, index_fin):
        """Return a new image with the voxels in ``[index_ini, index_fin)``.

        Only the first channel is kept. The affine of the new image is a copy
        of this one with its origin moved to the world position of
        ``index_ini``.

        Args:
            index_ini: Voxel indices ``(i, j, k)`` of the first voxel to keep.
            index_fin: Voxel indices ``(i, j, k)`` one past the last voxel to
                keep (Python slice semantics).
        """
        # TODO: add the rest of kwargs
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        patch = self.data[0, i0:i1, j0:j1, k0:k1].clone()
        return self.__class__(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            check_nans=self.check_nans,
        )
324
325
326
class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: A :py:attr:`type` other than :py:attr:`torchio.INTENSITY`
            is used for instantiation.
    """
    def __init__(self, *args, **kwargs):
        # An explicit ``type`` is accepted only when it matches the fixed one.
        # Generic code that rebuilds images via
        # ``self.__class__(..., type=self.type)`` (e.g. ``Image.crop``)
        # therefore also works for this class. The key is popped so it is not
        # passed twice to ``Image.__init__``.
        if kwargs.pop('type', INTENSITY) != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        super().__init__(*args, **kwargs, type=INTENSITY)
338
339
340
class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: A :py:attr:`type` other than :py:attr:`torchio.LABEL`
            is used for instantiation.
    """
    def __init__(self, *args, **kwargs):
        # An explicit ``type`` is accepted only when it matches the fixed one.
        # Generic code that rebuilds images via
        # ``self.__class__(..., type=self.type)`` (e.g. ``Image.crop``)
        # therefore also works for this class. The key is popped so it is not
        # passed twice to ``Image.__init__``.
        if kwargs.pop('type', LABEL) != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        super().__init__(*args, **kwargs, type=LABEL)
352