Passed
Pull Request — master (#204)
by Fernando
57s
created

torchio.data.dataset.ImagesDataset.set_load_image_data()    Grade: A

Complexity
    Conditions: 3

Size
    Total Lines: 7
    Code Lines: 6

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric  Value
cc      3
eloc    6
nop     2
dl      0
loc     7
rs      10
c       0
b       0
f       0
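
The graded method, set_load_image_data(), is not part of the excerpt below. A minimal sketch consistent with the reported metrics (2 parameters, 7 lines, 3 conditions) could look like the following; the attribute name, the guard against disabling loading while a transform is set, and the message text are assumptions, not taken from this report:

    def set_load_image_data(self, load_image_data: bool) -> None:
        # Sketch only: refuse to disable image loading while a transform
        # is set, since transforms need the image data to operate on.
        if not load_image_data and self._transform is not None:
            message = (
                'Load data cannot be set to False if a transform is set.'
                f' Current transform is {self._transform}')
            raise ValueError(message)
        self.load_image_data = load_image_data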
import copy
import collections
from typing import Dict, Sequence, Optional, Callable

from deprecated import deprecated
from torch.utils.data import Dataset

from ..utils import get_stem
from ..torchio import DATA, AFFINE, TYPE, PATH, STEM, TypePath
from .image import Image
from .io import write_image
from .subject import Subject


class ImagesDataset(Dataset):
    """Base TorchIO dataset.

    :py:class:`~torchio.data.dataset.ImagesDataset`
    is a reader of 3D medical images that directly
    inherits from :class:`torch.utils.data.Dataset`.
    It can be used with a :class:`torch.utils.data.DataLoader`
    for efficient loading and augmentation.
    It receives a list of subjects, where each subject is an instance of
    :py:class:`torchio.data.subject.Subject` containing instances of
    :py:class:`torchio.data.image.Image`.
    The file format must be compatible with `NiBabel`_ or `SimpleITK`_ readers.
    It can also be a directory containing
    `DICOM`_ files.

    Indexing an :py:class:`~torchio.data.dataset.ImagesDataset` returns an
    instance of :py:class:`~torchio.data.subject.Subject`. Check out the
    documentation for both classes for usage examples.

    Example:

        >>> sample = images_dataset[0]
        >>> sample
        Subject(Keys: ('image', 'label'); images: 2)
        >>> image = sample['image']  # or sample.image
        >>> image.shape
        torch.Size([1, 176, 256, 256])
        >>> image.affine
        array([[   0.03,    1.13,   -0.08,  -88.54],
               [   0.06,    0.08,    0.95, -129.66],
               [   1.18,   -0.06,   -0.11,  -67.15],
               [   0.  ,    0.  ,    0.  ,    1.  ]])

    Args:
        subjects: Sequence of instances of
            :class:`~torchio.data.subject.Subject`.
        transform: An instance of :py:class:`torchio.transforms.Transform`
            that will be applied to each sample.

    Example:
        >>> import torchio
        >>> from torchio import ImagesDataset, Image, Subject
        >>> from torchio.transforms import RescaleIntensity, RandomAffine, Compose
        >>> subject_a = Subject(
        ...     t1=Image('~/Dropbox/MRI/t1.nrrd', type=torchio.INTENSITY),
        ...     t2=Image('~/Dropbox/MRI/t2.mha', type=torchio.INTENSITY),
        ...     label=Image('~/Dropbox/MRI/t1_seg.nii.gz', type=torchio.LABEL),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = Subject(
        ...     t1=Image('/tmp/colin27_t1_tal_lin.minc', type=torchio.INTENSITY),
        ...     t2=Image('/tmp/colin27_t2_tal_lin_dicom', type=torchio.INTENSITY),
        ...     label=Image('/tmp/colin27_seg1.nii.gz', type=torchio.LABEL),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     RescaleIntensity((0, 1)),
        ...     RandomAffine(),
        ... ]
        >>> transform = Compose(transforms)
        >>> subjects_dataset = ImagesDataset(subjects_list, transform=transform)
        >>> subject_sample = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html
    """

    def __init__(
            self,
            subjects: Sequence[Subject],
            transform: Optional[Callable] = None,
            ):
        self._parse_subjects_list(subjects)
        self.subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)

    def __len__(self):
        return len(self.subjects)

    def __getitem__(self, index: int) -> dict:
        if not isinstance(index, int):
            raise ValueError(f'Index "{index}" must be int, not {type(index)}')
        subject = self.subjects[index]
        sample = copy.deepcopy(subject)

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            sample = self._transform(sample)
        return sample

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the :attr:`transform` attribute.

        Args:
            transform: An instance of :py:class:`torchio.transforms.Transform`.
        """
        if transform is not None and not callable(transform):
            raise ValueError(
                f'The transform must be a callable object, not {transform}')
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Sequence[Subject]) -> None:
        # Check that it's list or tuple
        if not isinstance(subjects_list, collections.abc.Sequence):
            raise TypeError(
                f'Subject list must be a sequence, not {type(subjects_list)}')

        # Check that it's not empty
        if not subjects_list:
            raise ValueError('Subjects list is empty')

        # Check each element
        for subject in subjects_list:
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)

    @classmethod
    @deprecated(
        'ImagesDataset.save_sample is deprecated. Use Image.save instead'
    )
    def save_sample(
            cls,
            sample: Subject,
            output_paths_dict: Dict[str, TypePath],
            ) -> None:
        for key, output_path in output_paths_dict.items():
            tensor = sample[key][DATA].squeeze()  # assume 2D if (1, 1, H, W)
            affine = sample[key][AFFINE]
            write_image(tensor, affine, output_path)
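
As a usage note, the class can be wrapped directly in a torch.utils.data.DataLoader, as the docstring states. Below is a minimal sketch, assuming subjects_list is built as in the docstring example above; the structure of the collated batch (each image key mapping to a dict whose 'data' entry is a batched 5D tensor) is what the default PyTorch collate function produces for Subject samples, not something defined in this file:

    from torch.utils.data import DataLoader
    from torchio import ImagesDataset

    dataset = ImagesDataset(subjects_list)
    loader = DataLoader(dataset, batch_size=2, num_workers=2)
    for batch in loader:
        # With the default collate function, each image key maps to a
        # dict whose 'data' entry is a batched tensor (B, C, D, H, W).
        inputs = batch['t1']['data']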