Passed: Push to master (596012...87da33) by Fernando, created 01:30

torchio.data.images.ImagesDataset.save_sample() — rating: A

Complexity: Conditions 2
Size: Total Lines 10, Code Lines 9
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Metric   Value
eloc     9
dl       0
loc      10
rs       9.95
c        0
b        0
f        0
cc       2
nop      3
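
The rated method appears near the end of the listing below: save_sample() takes three parameters (the nop value above) and its only decision point is the single for loop over the output paths, which accounts for the cyclomatic complexity (cc) of 2. The following is a minimal, hedged usage sketch rather than part of the analyzed file; file names and dictionary keys are illustrative, and it assumes Image, Subject and ImagesDataset are importable from the top-level torchio package, as the docstring examples below suggest.

# Hedged sketch: writing the images of a loaded sample back to disk.
# File names and dictionary keys are illustrative only.
import torchio
from torchio import Image, Subject, ImagesDataset

subject = Subject(
    t1=Image('t1.nii.gz', torchio.INTENSITY),
    label=Image('t1_seg.nii.gz', torchio.LABEL),
)
dataset = ImagesDataset([subject])
sample = dataset[0]  # loads a tensor and an affine for each Image

# save_sample() loops over this dictionary, removes the channels dimension
# of each tensor and calls write_image() for every entry
ImagesDataset.save_sample(
    sample,
    output_paths_dict={'t1': 't1_copy.nii.gz', 'label': 't1_seg_copy.nii.gz'},
)
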
import copy
import pprint
import warnings
import collections
from pathlib import Path
from typing import (
    Any,
    Dict,
    List,
    Tuple,
    Sequence,
    Optional,
    Callable,
)
import torch
from torch.utils.data import Dataset
import numpy as np
from ..utils import get_stem
from ..torchio import DATA, AFFINE, TYPE, PATH, STEM, TypePath
from .io import read_image, write_image


class Image(dict):
    r"""Class to store information about an image.

    Args:
        path: Path to a file that can be read by
            :mod:`SimpleITK` or :mod:`nibabel` or to a directory containing
            DICOM files.
        type_: Type of image, such as :attr:`torchio.INTENSITY` or
            :attr:`torchio.LABEL`. This will be used by the transforms to
            decide whether to apply an operation, or which interpolation to use
            when resampling.
        **kwargs: Items that will be added to the image dictionary within the
            subject sample.
    """

    def __init__(self, path: TypePath, type_: str, **kwargs: Dict[str, Any]):
        super().__init__(**kwargs)
        self.path = self._parse_path(path)
        self.type = type_
        self.is_sample = False  # set to True by ImagesDataset

    @staticmethod
    def _parse_path(path: TypePath) -> Path:
        try:
            path = Path(path).expanduser()
        except TypeError:
            message = f'Conversion to path not possible for variable: {path}'
            raise TypeError(message)
        if not (path.is_file() or path.is_dir()):  # might be a dir with DICOM
            raise FileNotFoundError(f'File not found: {path}')
        return path

    def load(self, check_nans: bool = True) -> Tuple[torch.Tensor, np.ndarray]:
        r"""Load the image from disk.

        The file is expected to be monomodal and 3D. A channels dimension is
        added to the tensor.

        Args:
            check_nans: If ``True``, issues a warning if NaNs are found
                in the image.

        Returns:
            Tuple containing a 4D data tensor of size
            :math:`(1, D_{in}, H_{in}, W_{in})`
            and a 2D 4x4 affine matrix.
        """
        tensor, affine = read_image(self.path)
        tensor = tensor.unsqueeze(0)  # add channels dimension
        if check_nans and torch.isnan(tensor).any():
            warnings.warn(f'NaNs found in file "{self.path}"')
        return tensor, affine


class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Example:

        >>> import torchio
        >>> from torchio import Image, Subject
        >>> # One way:
        >>> subject = Subject(
        ...     one_image=Image('path_to_image.nii.gz', torchio.INTENSITY),
        ...     a_segmentation=Image('path_to_seg.nii.gz', torchio.LABEL),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # Another way, if you want to create the mapping beforehand or the keys contain spaces:
        >>> subject_dict = {
        ...     'one image': Image('path_to_image.nii.gz', torchio.INTENSITY),
        ...     'a segmentation': Image('path_to_seg.nii.gz', torchio.LABEL),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Dict[str, Any]):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self.images = [
            (k, v) for (k, v) in self.items()
            if isinstance(v, Image)
        ]
        self._parse_images(self.images)
        self.is_sample = False  # set to True by ImagesDataset

    def __repr__(self):
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {len(self.images)})'
        )
        return string

    @staticmethod
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    def check_consistent_shape(self) -> None:
        shapes_dict = {}
        for key, image in self.items():
            if not isinstance(image, Image) or not image.is_sample:
                continue
            shapes_dict[key] = image[DATA].shape
        num_unique_shapes = len(set(shapes_dict.values()))
        if num_unique_shapes > 1:
            message = (
                'Images in sample have inconsistent shapes:'
                f'\n{pprint.pformat(shapes_dict)}'
            )
            raise ValueError(message)


class ImagesDataset(Dataset):
    """Base TorchIO dataset.

    :class:`~torchio.data.images.ImagesDataset`
    is a reader of 3D medical images that directly
    inherits from :class:`torch.utils.data.Dataset`.
    It can be used with a :class:`torch.utils.data.DataLoader`
    for efficient loading and augmentation.
    It receives a list of subjects, where each subject is an instance of
    :class:`~torchio.data.images.Subject` containing instances of
    :class:`~torchio.data.images.Image`.
    The file format must be compatible with `NiBabel`_ or `SimpleITK`_ readers.
    It can also be a directory containing
    `DICOM`_ files.

    Indexing an :class:`~torchio.data.images.ImagesDataset` returns a
    Python dictionary with the data corresponding to the queried subject.
    The keys in the dictionary are the names of the images passed to that
    subject, for example ``('t1', 't2', 'segmentation')``.

    The value corresponding to each image name is another dictionary
    ``image_dict`` with information about the image.
    The data is stored in ``image_dict[torchio.DATA]``,
    and the corresponding `affine matrix`_
    is in ``image_dict[torchio.AFFINE]``:

        >>> sample = images_dataset[0]
        >>> sample.keys()
        dict_keys(['image', 'label'])
        >>> image_dict = sample['image']
        >>> image_dict[torchio.DATA].shape
        torch.Size([1, 176, 256, 256])
        >>> image_dict[torchio.AFFINE]
        array([[   0.03,    1.13,   -0.08,  -88.54],
               [   0.06,    0.08,    0.95, -129.66],
               [   1.18,   -0.06,   -0.11,  -67.15],
               [   0.  ,    0.  ,    0.  ,    1.  ]])

    Args:
        subjects: Sequence of instances of
            :class:`~torchio.data.images.Subject`.
        transform: An instance of :py:class:`torchio.transforms.Transform`
            that will be applied to each sample.
        check_nans: If ``True``, issues a warning if NaNs are found
            in the image.
        load_image_data: If ``False``, image data and affine will not be loaded.
            These fields will be set to ``None`` in the sample. This can be
            used to quickly iterate over the samples to retrieve e.g. the
            image paths. If ``False``, transform must be ``None``.

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html

    """

    def __init__(
            self,
            subjects: Sequence[Subject],
            transform: Optional[Callable] = None,
            check_nans: bool = True,
            load_image_data: bool = True,
            ):
        self._parse_subjects_list(subjects)
        self.subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)
        self.check_nans = check_nans
        self._load_image_data: bool
        self.set_load_image_data(load_image_data)

    def __len__(self):
        return len(self.subjects)

    def __getitem__(self, index: int) -> dict:
        if not isinstance(index, int):
            raise ValueError(f'Index "{index}" must be int, not {type(index)}')
        subject = self.subjects[index]
        sample = self.get_sample_dict_from_subject(subject)

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            sample = self._transform(sample)
        return sample

    def get_sample_dict_from_subject(self, subject: Subject):
        """Create a dictionary of dictionaries with subject information.

        Args:
            subject: Instance of :py:class:`~torchio.data.images.Subject`.
        """
        subject_sample = copy.deepcopy(subject)
        for (key, value) in subject.items():
            if isinstance(value, Image):
                subject_sample[key] = self.get_image_dict_from_image(value)
            else:
                subject_sample[key] = value
        subject_sample.is_sample = True
        return subject_sample

    def get_image_dict_from_image(self, image: Image):
        """Create a dictionary with image information.

        Args:
            image: Instance of :py:class:`~torchio.data.images.Image`.

        Returns:
            Dictionary with keys
            :py:attr:`torchio.DATA`,
            :py:attr:`torchio.AFFINE`,
            :py:attr:`torchio.TYPE`,
            :py:attr:`torchio.PATH` and
            :py:attr:`torchio.STEM`.
        """
        if self._load_image_data:
            tensor, affine = image.load(check_nans=self.check_nans)
        else:
            tensor = affine = None
        image_dict = {
            DATA: tensor,
            AFFINE: affine,
            TYPE: image.type,
            PATH: str(image.path),
            STEM: get_stem(image.path),
        }
        image = copy.deepcopy(image)
        image.update(image_dict)
        image.is_sample = True
        return image

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the :attr:`transform` attribute.

        Args:
            transform: An instance of :py:class:`torchio.transforms.Transform`.
        """
        if transform is not None and not callable(transform):
            raise ValueError(
                f'The transform must be a callable object, not {transform}')
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Sequence[Subject]) -> None:
        # Check that it's a list or tuple
        if not isinstance(subjects_list, collections.abc.Sequence):
            raise TypeError(
                f'Subject list must be a sequence, not {type(subjects_list)}')

        # Check that it's not empty
        if not subjects_list:
            raise ValueError('Subjects list is empty')

        # Check each element
        for subject in subjects_list:
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)

    @classmethod
    def save_sample(
            cls,
            sample: Subject,
            output_paths_dict: Dict[str, TypePath],
            ) -> None:
        for key, output_path in output_paths_dict.items():
            tensor = sample[key][DATA][0]  # remove channels dim
            affine = sample[key][AFFINE]
            write_image(tensor, affine, output_path)

    def set_load_image_data(self, load_image_data: bool):
        if not load_image_data and self._transform is not None:
            message = (
                'Load data cannot be set to False if transform is not None.'
                f' Current transform is {self._transform}')
            raise ValueError(message)
        self._load_image_data = load_image_data
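
For context, here is a hedged end-to-end sketch of how the classes above are intended to be combined: Subject instances wrap Image instances, an ImagesDataset serves them to a torch.utils.data.DataLoader, and load_image_data=False skips reading voxel data when only metadata such as paths is needed. This is not part of the analyzed file; file names are illustrative, and it assumes the DATA and PATH constants are re-exported at the top level of torchio, as the docstring examples above do for torchio.DATA and torchio.AFFINE.

# Hedged sketch: feeding an ImagesDataset to a DataLoader, and using
# load_image_data=False to iterate over metadata without reading images.
# Assumes torchio re-exports ImagesDataset and the DATA/PATH constants.
import torchio
from torchio import Image, Subject, ImagesDataset
from torch.utils.data import DataLoader

subjects = [
    Subject(
        mri=Image(f'subject_{i}_t1.nii.gz', torchio.INTENSITY),
        seg=Image(f'subject_{i}_seg.nii.gz', torchio.LABEL),
    )
    for i in range(4)
]

# Full path: each sample holds a 4D tensor under the DATA key, so batches
# can be collated as long as the image shapes are consistent across subjects
dataset = ImagesDataset(subjects)
loader = DataLoader(dataset, batch_size=2)
for batch in loader:
    print(batch['mri'][torchio.DATA].shape)  # e.g. torch.Size([2, 1, D, H, W])

# Fast path: DATA and AFFINE are set to None, but TYPE, PATH and STEM are
# still populated, so the paths can be listed without reading any voxels
paths_dataset = ImagesDataset(subjects, load_image_data=False)
for index in range(len(paths_dataset)):
    sample = paths_dataset[index]
    print(sample['mri'][torchio.PATH])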