Passed
Pull Request #246 (master) by Fernando

torchio.data.image (rating: F)

Complexity
    Total Complexity    60

Size/Duplication
    Total Lines         376
    Duplicated Lines    0 %

Importance
    Changes             0

Metric    Value
eloc      202
dl        0
loc       376
rs        3.6
c         0
b         0
f         0
wmc       60

27 Methods

Rating  Name                        Duplication  Size  Complexity
A       Image.parse_affine()        0            9     4
A       Image.spatial_shape()       0            3     1
A       Image.parse_tensor()        0            7     3
A       Image.data()                0            3     1
A       Image.get_center()          0            15    2
A       Image.save()                0            14    3
A       LabelMap.__init__()         0            5     3
A       Image.memory()              0            3     1
A       Image.__repr__()            0            15    2
A       Image.type()                0            3     1
A       Image.crop()                0            12    3
A       Image._parse_path()         0            12    5
C       Image.__init__()            0            35    9
A       Image.load()                0            23    5
A       Image.set_check_nans()      0            2     1
A       Image.shape()               0            3     1
A       Image.numpy()               0            3     1
A       Image.__getitem__()         0            5     3
A       Image.orientation()         0            3     1
A       Image.is_2d()               0            2     1
A       Image.tensor()              0            3     1
A       Image.parse_tensor_shape()  0            2     1
A       Image.spacing()             0            4     1
A       Image.get_spacing_string()  0            4     1
A       ScalarImage.__init__()      0            5     3
A       Image.as_sitk()             0            3     1
A       Image.affine()              0            3     1
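Both Total Complexity and wmc (weighted methods per class) read 60. If wmc is the plain sum of the per-method complexity values, which is its usual definition, then the table accounts for that total exactly. The short Python check below sums the Complexity column and finds the largest contributor. All numbers are copied from the table above.

```python
# Complexity column of the methods table above, keyed by method name.
complexity = {
    "Image.parse_affine()": 4,
    "Image.spatial_shape()": 1,
    "Image.parse_tensor()": 3,
    "Image.data()": 1,
    "Image.get_center()": 2,
    "Image.save()": 3,
    "LabelMap.__init__()": 3,
    "Image.memory()": 1,
    "Image.__repr__()": 2,
    "Image.type()": 1,
    "Image.crop()": 3,
    "Image._parse_path()": 5,
    "Image.__init__()": 9,
    "Image.load()": 5,
    "Image.set_check_nans()": 1,
    "Image.shape()": 1,
    "Image.numpy()": 1,
    "Image.__getitem__()": 3,
    "Image.orientation()": 1,
    "Image.is_2d()": 1,
    "Image.tensor()": 1,
    "Image.parse_tensor_shape()": 1,
    "Image.spacing()": 1,
    "Image.get_spacing_string()": 1,
    "ScalarImage.__init__()": 3,
    "Image.as_sitk()": 1,
    "Image.affine()": 1,
}

total = sum(complexity.values())               # weighted methods per class
largest = max(complexity, key=complexity.get)  # method with the highest value

print(len(complexity), total)                  # 27 60
print(largest, complexity[largest])            # Image.__init__() 9
```

Image.__init__() contributes 9 of the 60 points by itself and is the only method rated C. That makes it a natural first target for any refactoring.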

How to fix: Complexity

Complex classes like the ones in torchio.data.image (here, Image) often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach is to look for fields and methods that share a prefix or suffix. In Image, for example, parse_tensor(), parse_tensor_shape(), parse_affine() and _parse_path() form one such group. The spatial members affine, spacing, orientation and get_center() form another.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and it is often faster.
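One possible Extract Class step is to move the spatial members of Image (the affine, spacing, get_center(), and the origin update done in crop()) into their own component. The sketch below is illustrative only and makes several assumptions:

* SpatialMetadata, ImageSketch and cropped() are hypothetical names and are not part of torchio.
* It uses plain Python lists instead of NumPy so that it stays self-contained.
* It computes spacing as the norms of the affine's first three columns, which is what nibabel's voxel_sizes computes. It does not call torchio's get_rotation_and_spacing_from_affine().

```python
import math
from dataclasses import dataclass
from typing import List, Sequence, Tuple

Matrix = List[List[float]]  # 4x4 affine as nested lists (assumption for this sketch)


def apply_affine(affine: Matrix, index: Sequence[float]) -> Tuple[float, float, float]:
    """Map a voxel index (i, j, k) to world coordinates with a 4x4 affine."""
    return tuple(
        sum(affine[row][col] * index[col] for col in range(3)) + affine[row][3]
        for row in range(3)
    )


@dataclass
class SpatialMetadata:
    """Hypothetical component grouping the spatial members of Image."""
    affine: Matrix
    spatial_shape: Tuple[int, int, int]

    @property
    def spacing(self) -> Tuple[float, float, float]:
        # Voxel size along each axis: norm of each of the first three columns.
        return tuple(
            math.sqrt(sum(self.affine[row][col] ** 2 for row in range(3)))
            for col in range(3)
        )

    def get_center(self, lps: bool = False) -> Tuple[float, float, float]:
        # Same logic as Image.get_center(): the center index is (size - 1) / 2.
        center_index = [(n - 1) / 2 for n in self.spatial_shape]
        r, a, s = apply_affine(self.affine, center_index)
        return (-r, -a, s) if lps else (r, a, s)

    def cropped(self, index_ini, index_fin) -> "SpatialMetadata":
        # Same idea as Image.crop(): move the origin to the first kept voxel.
        new_affine = [row[:] for row in self.affine]  # copy, keep the original
        for row, value in enumerate(apply_affine(self.affine, index_ini)):
            new_affine[row][3] = value
        new_shape = tuple(b - a for a, b in zip(index_ini, index_fin))
        return SpatialMetadata(new_affine, new_shape)


class ImageSketch(dict):
    """Hypothetical slimmer Image that delegates spatial queries."""

    def __init__(self, data, affine: Matrix, spatial_shape):
        super().__init__(data=data)
        self.spatial = SpatialMetadata(affine, tuple(spatial_shape))


affine = [
    [2.0, 0.0, 0.0, -10.0],
    [0.0, 2.0, 0.0, -20.0],
    [0.0, 0.0, 3.0, 5.0],
    [0.0, 0.0, 0.0, 1.0],
]
image = ImageSketch(data=None, affine=affine, spatial_shape=(5, 5, 3))
print(image.spatial.spacing)       # (2.0, 2.0, 3.0)
print(image.spatial.get_center())  # (-6.0, -16.0, 8.0)
print(image.spatial.cropped((1, 1, 0), (3, 4, 2)).affine[0][3])  # -8.0
```

With this split, Image would keep data handling and lazy loading, while the spatial arithmetic could be tested on its own without reading any file.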

import warnings
from pathlib import Path
from typing import Any, Dict, Tuple, Optional

import torch
import humanize
import numpy as np
import nibabel as nib
import SimpleITK as sitk

from ..utils import (
    nib_to_sitk,
    get_rotation_and_spacing_from_affine,
    get_stem,
    ensure_4d,
)
from ..torchio import (
    TypeData,
    TypePath,
    TypeTripletInt,
    TypeTripletFloat,
    DATA,
    TYPE,
    AFFINE,
    PATH,
    STEM,
    INTENSITY,
    LABEL,
)
from .io import read_image, write_image


PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM


class Image(dict):
    r"""TorchIO image.

    For information about medical image orientation, check out `NiBabel docs`_,
    the `3D Slicer wiki`_, `Graham Wideman's website`_, `FSL docs`_ or
    `SimpleITK docs`_.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files. If :py:attr:`tensor` is given, the data in
            :py:attr:`path` will not be read.    # TODO: document dims
        type: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling. For example, `preprocessing`_ and `augmentation`_
            intensity transforms will only be applied to images with type
            :attr:`torchio.INTENSITY`. Spatial transforms will be applied to
            all types, and nearest neighbor interpolation is always used to
            resample images with type :attr:`torchio.LABEL`.
            The type :attr:`torchio.SAMPLING_MAP` may be used with instances of
            :py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
        tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 3D
            :py:class:`torch.Tensor` or NumPy array with dimensions
            :math:`(D, H, W)`.   # TODO: document
        affine: If :attr:`path` is not given, :attr:`affine` must be a
            :math:`4 \times 4` NumPy array. If ``None``, :attr:`affine` is an
            identity matrix.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image. If ``False``, images will not be checked for the
            presence of NaNs.
        **kwargs: Items that will be added to the image dictionary within the
            subject sample.

    Example:
        >>> import torch
        >>> import torchio
        >>> # Loading from a file
        >>> t1_image = torchio.Image('t1.nii.gz', type=torchio.INTENSITY)
        >>> label_image = torchio.Image('t1_seg.nii.gz', type=torchio.LABEL)
        >>> image = torchio.Image(tensor=torch.rand(3, 4, 5))
        >>> image = torchio.Image('safe_image.nrrd', check_nans=False)
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    TorchIO images are `lazy loaders`_, i.e. the data is only loaded from disk
    when needed.

    Example:
        >>> import torchio
        >>> image = torchio.Image('t1.nii.gz')
        >>> image  # not loaded yet
        Image(path: "t1.nii.gz"; type: intensity)
        >>> times_two = 2 * image.data  # data is loaded and cached here
        >>> image
        Image(shape: (1, 256, 256, 176); spacing: (1.00, 1.00, 1.00); orientation: PIR+; memory: 44.0 MiB; type: intensity)
        >>> image.save('t1_copy.nii.gz')  # image itself is unchanged

    .. _lazy loaders: https://en.wikipedia.org/wiki/Lazy_loading
    .. _preprocessing: https://torchio.readthedocs.io/transforms/preprocessing.html#intensity
    .. _augmentation: https://torchio.readthedocs.io/transforms/augmentation.html#intensity
    .. _NiBabel docs: https://nipy.org/nibabel/image_orientation.html
    .. _3D Slicer wiki: https://www.slicer.org/wiki/Coordinate_systems
    .. _FSL docs: https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained
    .. _SimpleITK docs: https://simpleitk.readthedocs.io/en/master/fundamentalConcepts.html
    .. _Graham Wideman's website: http://www.grahamwideman.com/gw/brain/orientation/orientterms.htm
    """
    def __init__(
            self,
            path: Optional[TypePath] = None,
            type: str = INTENSITY,
            tensor: Optional[TypeData] = None,
            affine: Optional[TypeData] = None,
            check_nans: bool = True,
            num_spatial_dims: Optional[int] = None,  # TODO: document
            channels_last: bool = False,  # TODO: document
            **kwargs: Dict[str, Any],
            ):
        if path is None and tensor is None:
            raise ValueError('A value for path or tensor must be given')
        self._loaded = False
        self.num_spatial_dims = num_spatial_dims
        self.channels_last = channels_last
        tensor = self.parse_tensor(tensor)
        affine = self.parse_affine(affine)  # identity matrix if None
        if tensor is not None:
            self[DATA] = tensor
            self[AFFINE] = affine
            self._loaded = True
        for key in PROTECTED_KEYS:
            if key in kwargs:
                message = f'Key "{key}" is reserved. Use a different one'
                raise ValueError(message)

        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self[PATH] = '' if self.path is None else str(self.path)
        self[STEM] = '' if self.path is None else get_stem(self.path)
        self[TYPE] = type
        self.check_nans = check_nans

    def __repr__(self):
        properties = []
        if self._loaded:
            properties.extend([
                f'shape: {self.shape}',
                f'spacing: {self.get_spacing_string()}',
                f'orientation: {"".join(self.orientation)}+',
                f'memory: {humanize.naturalsize(self.memory, binary=True)}',
            ])
        else:
            properties.append(f'path: "{self.path}"')
        properties.append(f'type: {self.type}')
        properties = '; '.join(properties)
        string = f'{self.__class__.__name__}({properties})'
        return string

    def __getitem__(self, item):
        if item in (DATA, AFFINE):
            if item not in self:
                self.load()
        return super().__getitem__(item)

    @property
    def data(self):
        return self[DATA]

    @property
    def tensor(self):
        return self.data

    @property
    def affine(self):
        return self[AFFINE]

    @property
    def type(self):
        return self[TYPE]

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return tuple(self.data.shape)

    @property
    def spatial_shape(self) -> TypeTripletInt:
        return self.shape[1:]

    @property
    def orientation(self):
        return nib.aff2axcodes(self.affine)

    @property
    def spacing(self):
        _, spacing = get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def memory(self):
        return self.data.numel() * self.data.element_size()  # size in bytes for any dtype

    def get_spacing_string(self):
        strings = [f'{n:.2f}' for n in self.spacing]
        string = f'({", ".join(strings)})'
        return string

    @staticmethod
    def _parse_path(path: Optional[TypePath]) -> Optional[Path]:
        if path is None:
            return None
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message)
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def parse_tensor(self, tensor: TypeData) -> torch.Tensor:
        if tensor is None:
            return None
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        tensor = self.parse_tensor_shape(tensor)
        return tensor

    def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
        return ensure_4d(tensor, self.channels_last, self.num_spatial_dims)

    @staticmethod
    def parse_affine(affine: Optional[np.ndarray]) -> np.ndarray:
        if affine is None:
            return np.eye(4)
        if not isinstance(affine, np.ndarray):
            raise TypeError(f'Affine must be a NumPy array, not {type(affine)}')
        if affine.shape != (4, 4):
            raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
        return affine

    def load(self) -> None:
        r"""Load the image from disk.

        The file is expected to be monomodal/grayscale and 2D or 3D.
        A channels dimension is added to the tensor.

        Nothing is returned. The data are cached in the image dictionary:
        a 4D tensor of size :math:`(C, D, H, W)` is stored under
        :attr:`torchio.DATA` and the :math:`4 \times 4` affine matrix under
        :attr:`torchio.AFFINE`. If the image is already loaded or has no
        path, this method does nothing.
        """
        if self._loaded:
            return
        if self.path is None:
            return
        tensor, affine = read_image(self.path)
        tensor = self.parse_tensor_shape(tensor)

        if self.check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        self[DATA] = tensor
        self[AFFINE] = affine
        self._loaded = True

    def save(self, path, squeeze=False, channels_last=False):
        """Save image to disk.

        Args:
            path: String or instance of :py:class:`pathlib.Path`.
            squeeze: If ``True``, dimensions of size 1 are removed from the
                tensor before writing.
            channels_last: If ``True``, the channels dimension is moved to
                the end, i.e. the tensor is written as :math:`(D, H, W, C)`.
        """
        # Currently not working well for 2D images
        tensor = self[DATA]
        if channels_last:
            tensor = tensor.permute(1, 2, 3, 0)  # (D, H, W, C)
        if squeeze:
            tensor = tensor.squeeze()
        affine = self[AFFINE]
        write_image(tensor, affine, path)

    def is_2d(self) -> bool:
        return self.shape[-3] == 1

    def numpy(self) -> np.ndarray:
        """Get a NumPy array containing the image data."""
        return self[DATA].numpy()

    def as_sitk(self) -> sitk.Image:
        """Get the image as an instance of :py:class:`sitk.Image`."""
        return nib_to_sitk(self[DATA], self[AFFINE])

    def get_center(self, lps: bool = False) -> TypeTripletFloat:
        """Get image center in RAS+ or LPS+ coordinates.

        Args:
            lps: If ``True``, the coordinates will be in LPS+ orientation, i.e.
                the first dimension grows towards the left, etc. Otherwise, the
                coordinates will be in RAS+ orientation.
        """
        size = np.array(self.spatial_shape)
        center_index = (size - 1) / 2
        r, a, s = nib.affines.apply_affine(self.affine, center_index)
        if lps:
            return (-r, -a, s)
        else:
            return (r, a, s)

    def set_check_nans(self, check_nans):
        self.check_nans = check_nans

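    def crop(self, index_ini, index_fin):
        """Return a copy of the image cropped to the given voxel indices.

        Along each spatial axis, the patch spans from ``index_ini``
        (inclusive) to ``index_fin`` (exclusive). The translation of the
        affine is updated so that the patch keeps its position in world
        coordinates.
        """
        new_origin = nib.affines.apply_affine(self.affine, index_ini)
        new_affine = self.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin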
        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()  # keep all channels
        kwargs = dict(
            tensor=patch,
            affine=new_affine,
            type=self.type,
            path=self.path,
        )
        for key, value in self.items():
            if key in PROTECTED_KEYS:
                continue
            kwargs[key] = value  # should I copy? deepcopy?
        return self.__class__(**kwargs)


class ScalarImage(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.

    Example:
        >>> import torch
        >>> import torchio
        >>> image = torchio.ScalarImage('t1.nii.gz')  # loading from a file
        >>> image = torchio.ScalarImage(tensor=torch.rand(128, 128, 68))  # from tensor
        >>> data, affine = image.data, image.affine
        >>> affine.shape
        (4, 4)
        >>> image.data is image[torchio.DATA]
        True
        >>> image.data is image.tensor
        True
        >>> type(image.data)
        <class 'torch.Tensor'>

    See :py:class:`~torchio.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than
            :attr:`torchio.INTENSITY` is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != INTENSITY:
            raise ValueError('Type of ScalarImage is always torchio.INTENSITY')
        kwargs.update({'type': INTENSITY})
        super().__init__(*args, **kwargs)


class LabelMap(Image):
    """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.LABEL`.

    Example:
        >>> import torch
        >>> import torchio
        >>> labels = torchio.LabelMap(tensor=torch.rand(128, 128, 68) > 0.5)
        >>> labels = torchio.LabelMap('t1_seg.nii.gz')  # loading from a file

    See :py:class:`~torchio.data.image.Image` for more information.

    Raises:
        ValueError: If a :py:attr:`type` other than :attr:`torchio.LABEL`
            is given.
    """
    def __init__(self, *args, **kwargs):
        if 'type' in kwargs and kwargs['type'] != LABEL:
            raise ValueError('Type of LabelMap is always torchio.LABEL')
        kwargs.update({'type': LABEL})
        super().__init__(*args, **kwargs)
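The lazy loading described in the Image docstring comes from two pieces working together. First, __getitem__() calls load() the first time DATA or AFFINE is requested. Second, load() stores the result in the dictionary, so later accesses hit the cache instead of the disk.

A stripped-down, stdlib-only sketch of the same pattern follows. It makes these assumptions:

* LazyImage and fake_reader are hypothetical names.
* The string keys stand in for torchio.DATA and torchio.AFFINE.
* A plain callable replaces read_image(path).

```python
from typing import Any, Callable, Tuple

DATA = "data"      # hypothetical stand-in for torchio.DATA
AFFINE = "affine"  # hypothetical stand-in for torchio.AFFINE


class LazyImage(dict):
    """Minimal dict that reads its data only when DATA or AFFINE is requested."""

    def __init__(self, reader: Callable[[], Tuple[Any, Any]], **kwargs: Any):
        super().__init__(**kwargs)
        self._reader = reader  # stands in for read_image(self.path)
        self._loaded = False

    def __getitem__(self, item):
        # Same check as Image.__getitem__(): load on first access to a data key.
        if item in (DATA, AFFINE) and item not in self:
            self.load()
        return super().__getitem__(item)

    def load(self) -> None:
        if self._loaded:
            return
        data, affine = self._reader()
        self[DATA] = data  # cached: later lookups find the key directly
        self[AFFINE] = affine
        self._loaded = True


calls = []


def fake_reader():
    calls.append(1)  # count how many times the "disk" is read
    return [1.0, 2.0, 3.0], "identity"


image = LazyImage(fake_reader, type="intensity")
print(len(calls))     # 0: nothing read yet
print(image[DATA])    # [1.0, 2.0, 3.0]
print(image[AFFINE])  # identity
print(len(calls))     # 1: the reader ran once and the result is cached
```

This design has one consequence that applies to the sketch and, since Image overrides only __getitem__(), to Image as well. Only subscripting triggers loading. Inherited dict methods such as get() or items() bypass the overridden __getitem__(), so calling image.get(DATA) on an unloaded image returns None instead of reading the file.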