Passed
Push — master ( 79d509...08948f )
by Fernando
01:20
created

torchio.data.dataset.SubjectsDataset.save_sample()   A

Complexity

Conditions 2

Size

Total Lines 13
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 3
dl 0
loc 13
rs 9.85
c 0
b 0
f 0
1
import collections
import collections.abc
import copy
import operator
from typing import Sequence, Optional, Callable

from torch.utils.data import Dataset

from .subject import Subject
8
9
10
class SubjectsDataset(Dataset):
    """Base TorchIO dataset.

    :class:`~torchio.data.dataset.SubjectsDataset`
    is a reader of 3D medical images that directly
    inherits from :class:`torch.utils.data.Dataset`.
    It can be used with a :class:`torch.utils.data.DataLoader`
    for efficient loading and augmentation.
    It receives a list of instances of
    :class:`torchio.data.subject.Subject`.

    Each access (``dataset[i]``) returns a deep copy of the stored subject,
    so loading and transforming never modify the subjects passed to the
    constructor.

    Args:
        subjects: Non-empty sequence (e.g. list or tuple) of instances of
            :class:`~torchio.data.subject.Subject`.
        transform: An instance of :class:`torchio.transforms.Transform`
            (or any callable) that will be applied to each subject.

    Raises:
        TypeError: If ``subjects`` is not a sequence or contains an element
            that is not a :class:`~torchio.data.subject.Subject`.
        ValueError: If ``subjects`` is empty or ``transform`` is neither
            ``None`` nor callable.

    Example:
        >>> import torchio as tio
        >>> subject_a = tio.Subject(
        ...     t1=tio.ScalarImage('t1.nrrd',),
        ...     t2=tio.ScalarImage('t2.mha',),
        ...     label=tio.LabelMap('t1_seg.nii.gz'),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = tio.Subject(
        ...     t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
        ...     t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
        ...     label=tio.LabelMap('colin27_seg1.nii.gz'),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     tio.RescaleIntensity((0, 1)),
        ...     tio.RandomAffine(),
        ... ]
        >>> transform = tio.Compose(transforms)
        >>> subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
        >>> subject = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html
    """

    def __init__(
            self,
            subjects: Sequence[Subject],
            transform: Optional[Callable] = None,
            ):
        # Validate before storing so a bad list fails fast at construction.
        self._parse_subjects_list(subjects)
        self.subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)

    def __len__(self):
        """Return the number of subjects in the dataset."""
        return len(self.subjects)

    def __getitem__(self, index: int) -> Subject:
        """Return a loaded, optionally transformed, copy of a subject.

        Args:
            index: Position of the subject in :attr:`subjects`. Any object
                implementing ``__index__`` is accepted: a Python
                :class:`int`, a NumPy integer, or a single-element integer
                :class:`torch.Tensor`, as produced by some samplers.

        Returns:
            A deep copy of the stored subject, after calling its ``load()``
            method and, if set, applying the dataset transform.

        Raises:
            ValueError: If ``index`` cannot be losslessly interpreted as an
                integer (e.g. a ``float`` or a ``str``).
        """
        try:
            # operator.index accepts integer-like objects (NumPy / tensor
            # integers) but, unlike int(), refuses to truncate floats.
            index = operator.index(index)
        except TypeError as error:
            raise ValueError(
                f'Index "{index}" must be int, not {type(index)}') from error
        subject = self.subjects[index]
        # Deep copy so loading/transforming never mutates the stored subject.
        # This is cheap as long as the stored subject's images have not been
        # loaded yet.
        subject = copy.deepcopy(subject)
        subject.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the transform applied to each subject in :meth:`__getitem__`.

        Args:
            transform: An instance of :class:`torchio.transforms.Transform`,
                any other callable, or ``None`` to disable transforms.

        Raises:
            ValueError: If ``transform`` is neither ``None`` nor callable.
        """
        if transform is not None and not callable(transform):
            raise ValueError(
                f'The transform must be a callable object, not {transform}')
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Sequence[Subject]) -> None:
        """Validate the subjects passed to the constructor.

        Raises:
            TypeError: If ``subjects_list`` is not a sequence, or any of its
                elements is not a :class:`~torchio.data.subject.Subject`.
            ValueError: If ``subjects_list`` is empty.
        """
        # Check that it's list or tuple
        if not isinstance(subjects_list, collections.abc.Sequence):
            raise TypeError(
                f'Subject list must be a sequence, not {type(subjects_list)}')

        # Check that it's not empty
        if not subjects_list:
            raise ValueError('Subjects list is empty')

        # Check each element
        for subject in subjects_list:
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)
113