Passed
Push — master ( b06930...3ddbe5 )
by Fernando
03:55
created

torchio.data.queue.Queue.get_next_subject()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 2     |
| eloc   | 8     |
| nop    | 1     |
| dl     | 0     |
| loc    | 9     |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import random
2
import warnings
3
from itertools import islice
4
from typing import List, Iterator
5
6
from tqdm import trange
7
from torch.utils.data import Dataset, DataLoader
8
9
from .subject import Subject
10
from .sampler import PatchSampler
11
from .dataset import SubjectsDataset
12
13
14
class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.SubjectsDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
            It must be at least :attr:`samples_per_volume`; otherwise not
            even one subject fits in the queue.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A sampler used to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed.

    This sketch can be used to experiment and understand how the queue works.
    In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> import torch
    >>> import torchio as tio
    >>> from torch.utils.data import DataLoader
    >>> patch_size = 96
    >>> queue_length = 300
    >>> samples_per_volume = 10
    >>> sampler = tio.data.UniformSampler(patch_size)
    >>> subject = tio.datasets.Colin27()
    >>> subjects_dataset = tio.SubjectsDataset(10 * [subject])
    >>> patches_queue = tio.Queue(
    ...     subjects_dataset,
    ...     queue_length,
    ...     samples_per_volume,
    ...     sampler,
    ...     num_workers=4,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=16)
    >>> num_epochs = 2
    >>> model = torch.nn.Identity()
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
    ...         targets = patches_batch['brain'][tio.DATA]  # key 'brain' is in subject
    ...         logits = model(inputs)  # model being an instance of torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: SubjectsDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        # Must be set before get_subjects_iterable(), which calls _print()
        self.verbose = verbose
        # Loader over the subjects; recreated in get_next_subject() whenever
        # it is exhausted
        self.subjects_iterable = self.get_subjects_iterable()
        # Patches waiting to be served; __getitem__ pops from the end
        self.patches_list: List[dict] = []
        # Total number of patches returned so far (reported by __repr__)
        self.num_sampled_patches = 0

    def __len__(self):
        """Return the number of patches in one epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _):
        """Pop and return one patch, refilling the queue first if it is empty.

        The index argument is ignored: patches are served in the order in
        which they are stored in :attr:`patches_list`.
        """
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self._print('Patches list is empty.')
            self.fill()
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def _print(self, *args):
        """Print ``args`` only when :attr:`verbose` is ``True``."""
        if self.verbose:
            print(*args)  # noqa: T001

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the underlying dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches served per epoch (subjects x samples)."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Load subjects and append their patches to the queue.

        At most ``max_length // samples_per_volume`` subjects are loaded
        (and never more than the dataset contains), and up to
        ``samples_per_volume`` patches are taken from each of them. The
        patches are shuffled afterwards if :attr:`shuffle_patches` is
        ``True``.

        Raises:
            ValueError: If :attr:`max_length` is smaller than
                :attr:`samples_per_volume`, because then not even one
                subject fits in the queue and it would stay empty.
        """
        assert self.sampler is not None

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        if max_num_subjects_for_queue == 0:
            # Without this check, nothing would be loaded and __getitem__
            # would fail with an opaque "pop from empty list" IndexError
            message = (
                f'Queue length ({self.max_length}) must be at least the'
                f' number of patches per volume ({self.samples_per_volume})'
            )
            raise ValueError(message)

        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self._print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            subjects_range = trange(num_subjects_for_queue, leave=False)
        else:
            subjects_range = range(num_subjects_for_queue)
        for _ in subjects_range:
            subject = self.get_next_subject()
            patches_iterable = self.sampler(subject)
            patches = list(islice(patches_iterable, self.samples_per_volume))
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject(self) -> Subject:
        """Return the next subject, restarting the subjects loader if needed.

        When the current loader is exhausted, a new one is created with
        :meth:`get_subjects_iterable`, which reshuffles the subjects if
        :attr:`shuffle_subjects` is ``True``.
        """
        # StopIteration is expected once every subject of the current
        # loader has been yielded
        try:
            subject = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject = next(self.subjects_iterable)
        return subject

    @staticmethod
    def _get_first_item(batch):
        """Collate function that unwraps a batch of size one.

        This is a named function rather than a lambda so that it can be
        pickled. Pickling is required when the DataLoader workers are
        started with the ``spawn`` method (the default on Windows and on
        macOS).
        """
        return batch[0]

    def get_subjects_iterable(self) -> Iterator:
        """Create a fresh iterator over the subjects dataset.

        A DataLoader is used so that subjects can be loaded and transformed
        in :attr:`num_workers` subprocesses.
        """
        # The loader uses the default batch size of 1, and the collate
        # function unwraps each single-element batch so that the loader
        # yields one subject at a time
        self._print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=self._get_first_item,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
193