Passed
Push — master ( 7bf0dc...387cc1 )
by Fernando
01:06
created

torchio.data.dataset.SubjectsDataset.__len__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import copy
2
import collections
3
from typing import Dict, Sequence, Optional, Callable
4
5
from deprecated import deprecated
6
from torch.utils.data import Dataset
7
8
from ..utils import get_stem
9
from ..torchio import DATA, AFFINE, TYPE, PATH, STEM, TypePath
10
from .image import Image
11
from .io import write_image
12
from .subject import Subject
13
14
15
class SubjectsDataset(Dataset):
    """Base TorchIO dataset.

    :py:class:`~torchio.data.dataset.SubjectsDataset`
    is a reader of 3D medical images that directly
    inherits from :class:`torch.utils.data.Dataset`.
    It can be used with a :class:`torch.utils.data.DataLoader`
    for efficient loading and augmentation.
    It receives a list of instances of
    :py:class:`torchio.data.subject.Subject`.

    Args:
        subjects: Non-empty sequence (e.g. a list or tuple) of instances of
            :class:`~torchio.data.subject.Subject`.
        transform: An instance of :py:class:`torchio.transforms.Transform`
            (or any other callable) that will be applied to each sample.
            If ``None``, samples are returned without being transformed.

    Raises:
        TypeError: If ``subjects`` is not a sequence, or if any of its
            elements is not an instance of
            :class:`~torchio.data.subject.Subject`.
        ValueError: If ``subjects`` is empty or ``transform`` is not
            callable.

    Example:
        >>> import torchio
        >>> from torchio import SubjectsDataset, ScalarImage, LabelMap, Subject
        >>> from torchio.transforms import RescaleIntensity, RandomAffine, Compose
        >>> subject_a = Subject(
        ...     t1=ScalarImage('t1.nrrd',),
        ...     t2=ScalarImage('t2.mha',),
        ...     label=LabelMap('t1_seg.nii.gz'),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = Subject(
        ...     t1=ScalarImage('colin27_t1_tal_lin.minc',),
        ...     t2=ScalarImage('colin27_t2_tal_lin_dicom',),
        ...     label=LabelMap('colin27_seg1.nii.gz'),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     RescaleIntensity((0, 1)),
        ...     RandomAffine(),
        ... ]
        >>> transform = Compose(transforms)
        >>> subjects_dataset = SubjectsDataset(subjects_list, transform=transform)
        >>> subject_sample = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html
    """

    def __init__(
            self,
            subjects: Sequence[Subject],
            transform: Optional[Callable] = None,
            ):
        # Validate before storing so an invalid dataset is never built.
        self._parse_subjects_list(subjects)
        self.subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)

    def __len__(self) -> int:
        """Return the number of subjects in the dataset."""
        return len(self.subjects)

    def __getitem__(self, index: int) -> Subject:
        """Return a loaded, optionally transformed copy of a subject.

        Args:
            index: Position of the subject in :attr:`subjects`. Besides a
                Python ``int``, any object implementing ``__index__`` is
                accepted (e.g. a NumPy integer or a single-element integer
                tensor, as some samplers produce) and converted to ``int``.

        Returns:
            A deep copy of the stored subject after calling ``load()`` on it
            and, if a transform was set, after applying that transform.

        Raises:
            ValueError: If ``index`` cannot be interpreted as an integer.
            IndexError: If ``index`` is out of range for :attr:`subjects`.
        """
        import operator
        try:
            # operator.index accepts int (and bool, as before) plus any
            # integer-like type, and raises TypeError for floats, slices, etc.
            position = operator.index(index)
        except TypeError:
            raise ValueError(
                f'Index "{index}" must be int, not {type(index)}') from None
        subject = self.subjects[position]
        # Copy so that loading/transforming never mutates the stored subject.
        # Presumably cheap as long as the stored images are not loaded yet.
        sample = copy.deepcopy(subject)
        sample.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            sample = self._transform(sample)
        return sample

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the :attr:`transform` attribute.

        Args:
            transform: An instance of :py:class:`torchio.transforms.Transform`
                (or any callable), or ``None`` to disable transforms.

        Raises:
            ValueError: If ``transform`` is neither ``None`` nor callable.
        """
        if transform is not None and not callable(transform):
            raise ValueError(
                f'The transform must be a callable object, not {transform}')
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Sequence[Subject]) -> None:
        """Validate the subjects passed to the constructor.

        Raises:
            TypeError: If ``subjects_list`` is not a sequence or contains an
                element that is not a :class:`Subject`.
            ValueError: If ``subjects_list`` is empty.
        """
        # Explicit submodule import: ``import collections`` alone does not
        # guarantee that ``collections.abc`` has been loaded.
        import collections.abc

        # Check that it's list or tuple
        if not isinstance(subjects_list, collections.abc.Sequence):
            raise TypeError(
                f'Subject list must be a sequence, not {type(subjects_list)}')

        # Check that it's not empty
        if not subjects_list:
            raise ValueError('Subjects list is empty')

        # Check each element
        for subject in subjects_list:
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)

    @classmethod
    @deprecated(
        'SubjectsDataset.save_sample is deprecated. Use Image.save instead'
    )
    def save_sample(
            cls,
            sample: Subject,
            output_paths_dict: Dict[str, TypePath],
            ) -> None:
        """Write selected images of ``sample`` to disk (deprecated).

        Args:
            sample: Subject whose entries are read as ``sample[key][DATA]``
                and ``sample[key][AFFINE]``.
            output_paths_dict: Mapping from a key in ``sample`` to the path
                that entry's tensor and affine are written to via
                ``write_image``.
        """
        for key, output_path in output_paths_dict.items():
            tensor = sample[key][DATA]
            affine = sample[key][AFFINE]
            write_image(tensor, affine, output_path)
134
135
136
@deprecated('ImagesDataset is deprecated. Use SubjectsDataset instead.')
class ImagesDataset(SubjectsDataset):
    """Deprecated alias of :class:`SubjectsDataset`.

    Kept only for backward compatibility; it adds no behavior of its own.
    New code should use :class:`SubjectsDataset` directly.
    """
139