Passed
Push — master ( e83024...85c52f )
by Fernando
01:01
created

torchio.data.dataset.SubjectsDataset.dry_iter()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import copy
2
from typing import Sequence, Optional, Callable, Iterable
3
4
from torch.utils.data import Dataset
5
6
from .subject import Subject
7
8
9
class SubjectsDataset(Dataset):
    """Base TorchIO dataset.

    Reader of 3D medical images that directly inherits from the PyTorch
    :class:`~torch.utils.data.Dataset`. It can be used with a PyTorch
    :class:`~torch.utils.data.DataLoader` for efficient loading and
    augmentation. It receives a list of instances of :class:`~torchio.Subject`
    and an optional transform applied to the volumes after loading.

    Args:
        subjects: List of instances of :class:`~torchio.Subject`. One-shot
            iterators (e.g. generators) are also accepted; they are
            materialized into a list so that they can be validated, indexed
            and measured with :func:`len`.
        transform: An instance of :class:`~torchio.transforms.Transform`
            that will be applied to each subject.
        load_getitem: If ``True``, load all the images of a subject before
            returning it in :meth:`__getitem__`. Set it to ``False`` if some
            of the images will not be needed during training.

    Example:
        >>> import torchio as tio
        >>> subject_a = tio.Subject(
        ...     t1=tio.ScalarImage('t1.nrrd',),
        ...     t2=tio.ScalarImage('t2.mha',),
        ...     label=tio.LabelMap('t1_seg.nii.gz'),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = tio.Subject(
        ...     t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
        ...     t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
        ...     label=tio.LabelMap('colin27_seg1.nii.gz'),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     tio.RescaleIntensity((0, 1)),
        ...     tio.RandomAffine(),
        ... ]
        >>> transform = tio.Compose(transforms)
        >>> subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
        >>> subject = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html

    .. tip:: To quickly iterate over the subjects without loading the images,
        use :meth:`dry_iter()`.
    """  # noqa: E501

    def __init__(
            self,
            subjects: Sequence[Subject],
            transform: Optional[Callable] = None,
            load_getitem: bool = True,
            ):
        # Local import so the module-level import block stays unchanged.
        from collections.abc import Iterator

        # A one-shot iterator (e.g. a generator) would be exhausted by the
        # validation below and supports neither len() nor indexing, so it is
        # materialized first. Other containers are stored as given, so
        # dry_iter() keeps returning the caller's own object.
        if isinstance(subjects, Iterator):
            subjects = list(subjects)
        self._parse_subjects_list(subjects)
        self._subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)
        self.load_getitem = load_getitem

    def __len__(self):
        return len(self._subjects)

    def __getitem__(self, index: int) -> Subject:
        """Return a (possibly loaded and transformed) copy of a subject.

        Args:
            index: Position of the subject. Any integer-like object
                implementing ``__index__`` is accepted.

        Raises:
            ValueError: If ``index`` is not integer-like.
        """
        import operator

        try:
            # operator.index accepts Python ints and other integer types that
            # implement __index__ (e.g. NumPy integer scalars), while still
            # rejecting floats and strings instead of silently truncating.
            index = operator.index(index)
        except TypeError as e:
            message = f'Index "{index}" must be int, not {type(index)}'
            raise ValueError(message) from e
        subject = self._subjects[index]
        # Copy so that loading/transforming never mutates the stored subject.
        # This is cheap as long as the stored images have not been loaded.
        subject = copy.deepcopy(subject)
        if self.load_getitem:
            subject.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject

    def dry_iter(self):
        """Return the internal list of subjects.

        This can be used to iterate over the subjects without loading the data
        and applying any transforms::

            >>> names = [subject.name for subject in dataset.dry_iter()]

        Note that the internal container itself is returned, not a copy.
        """
        return self._subjects

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the transform applied to each subject in :meth:`__getitem__`.

        Args:
            transform: Callable object, typically an instance of a subclass
                of :class:`torchio.transforms.Transform`, or ``None`` to
                disable transforms.

        Raises:
            ValueError: If ``transform`` is neither ``None`` nor callable.
        """
        if transform is not None and not callable(transform):
            message = (
                'The transform must be a callable object,'
                f' but it has type {type(transform)}'
            )
            raise ValueError(message)
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Iterable[Subject]) -> None:
        """Validate the subjects passed to the constructor.

        Raises:
            TypeError: If ``subjects_list`` is not iterable or contains an
                object that is not a :class:`~torchio.Subject`.
            ValueError: If ``subjects_list`` is empty.
        """
        # Check that it's an iterable
        try:
            iter(subjects_list)
        except TypeError as e:
            message = (
                f'Subject list must be an iterable, not {type(subjects_list)}'
            )
            raise TypeError(message) from e

        # Check each element, counting them on the way. Emptiness is decided
        # from this count rather than from truthiness, which is ambiguous for
        # e.g. NumPy arrays and always true for iterables without __len__.
        num_subjects = 0
        for subject in subjects_list:
            num_subjects += 1
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)
        if num_subjects == 0:
            raise ValueError('Subjects list is empty')
136