Completed
Push — master ( 33fb47...58ba59 )
by Fernando
01:05
created

torchio.data.queue.Queue.subjects_iterable()   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
import random
2
import warnings
3
from itertools import islice
4
from typing import List, Iterator
5
6
from tqdm import trange
7
from torch.utils.data import Dataset, DataLoader
8
9
from .subject import Subject
10
from .sampler import PatchSampler
11
from .dataset import SubjectsDataset
12
13
14
class Queue(Dataset):
    r"""Queue used for stochastic patch-based training.

    A training iteration (i.e., forward and backward pass) performed on a
    GPU is usually faster than loading, preprocessing, augmenting, and cropping
    a volume on a CPU.
    Most preprocessing operations could be performed using a GPU,
    but these devices are typically reserved for training the CNN so that batch
    size and input tensor size can be as large as possible.
    Therefore, it is beneficial to prepare (i.e., load, preprocess and augment)
    the volumes using multiprocessing CPU techniques in parallel with the
    forward-backward passes of a training iteration.
    Once a volume is appropriately prepared, it is computationally beneficial to
    sample multiple patches from a volume rather than having to prepare the same
    volume each time a patch needs to be extracted.
    The sampled patches are then stored in a buffer or *queue* until
    the next training iteration, at which point they are loaded onto the GPU
    for inference.
    For this, TorchIO provides the :class:`~torchio.data.Queue` class, which also
    inherits from the PyTorch :class:`~torch.utils.data.Dataset`.
    In this queueing system,
    samplers behave as generators that yield patches from random locations
    in volumes contained in the :class:`~torchio.data.SubjectsDataset`.

    The end of a training epoch is defined as the moment after which patches
    from all subjects have been used for training.
    At the beginning of each training epoch,
    the subjects list in the :class:`~torchio.data.SubjectsDataset` is shuffled,
    as is typically done in machine learning pipelines to increase variance
    of training instances during model optimization.
    A PyTorch loader queries the datasets copied in each process,
    which load and process the volumes in parallel on the CPU.
    A patches list is filled with patches extracted by the sampler,
    and the queue is shuffled once it has reached a specified maximum length so
    that batches are composed of patches from different subjects.
    The internal data loader continues querying the
    :class:`~torchio.data.SubjectsDataset` using multiprocessing.
    The patches list, when emptied, is refilled with new patches.
    A second data loader, external to the queue,
    may be used to collate batches of patches stored in the queue,
    which are passed to the neural network.

    Args:
        subjects_dataset: Instance of :class:`~torchio.data.SubjectsDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A subclass of :class:`~torchio.data.sampler.PatchSampler` used
            to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        pin_memory: See :attr:`pin_memory` in
            :class:`~torch.utils.data.DataLoader`.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        start_background: If ``True``, the loader will start working in the
            background as soon as the queue is instantiated.
        verbose: If ``True``, some debugging messages will be printed.

    This diagram represents the connection between
    a :class:`~torchio.data.SubjectsDataset`,
    a :class:`~torchio.data.Queue`
    and the :class:`~torch.utils.data.DataLoader` used to pop batches from the
    queue.

    .. image:: https://raw.githubusercontent.com/fepegar/torchio/master/docs/images/diagram_patches.svg
        :alt: Training with patches

    This sketch can be used to experiment and understand how the queue works.
    In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue, so you should always use ``num_workers=0`` for
        the :class:`~torch.utils.data.DataLoader` you instantiate to generate
        training batches.

    Example:

    >>> import torch
    >>> import torchio as tio
    >>> from torch.utils.data import DataLoader
    >>> patch_size = 96
    >>> queue_length = 300
    >>> samples_per_volume = 10
    >>> sampler = tio.data.UniformSampler(patch_size)
    >>> subject = tio.datasets.Colin27()
    >>> subjects_dataset = tio.SubjectsDataset(10 * [subject])
    >>> patches_queue = tio.Queue(
    ...     subjects_dataset,
    ...     queue_length,
    ...     samples_per_volume,
    ...     sampler,
    ...     num_workers=4,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=16)
    >>> num_epochs = 2
    >>> model = torch.nn.Identity()
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
    ...         targets = patches_batch['brain'][tio.DATA]  # key 'brain' is in subject
    ...         logits = model(inputs)  # model being an instance of torch.nn.Module

    """  # noqa: E501
    def __init__(
            self,
            subjects_dataset: SubjectsDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            pin_memory: bool = True,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            start_background: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.verbose = verbose
        # Patches waiting to be popped by __getitem__
        self.patches_list: List[Subject] = []
        # Number of patches returned by __getitem__ so far
        self.num_sampled_patches = 0
        # Iterator over the subjects loader; stays None until first needed
        # unless start_background is True
        self._subjects_iterable = None
        # All plain state is set before the loader is started, so the
        # instance is fully initialized by the time the iterator is created
        if start_background:
            self.initialize_subjects_iterable()

    def __len__(self) -> int:
        """Return the number of patches that make up one epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _):
        """Pop and return one patch, refilling the queue if it is empty.

        The index argument is ignored: patches are always taken from the
        end of the internal patches list.

        Raises:
            RuntimeError: If the patches list is still empty after trying
                to fill it.
        """
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self._print('Patches list is empty.')
            self.fill()
            if not self.patches_list:
                # Fail with an explicit error instead of the bare
                # "pop from empty list" IndexError, which the legacy
                # sequence iteration protocol would silently interpret as
                # the end of the data
                message = (
                    'The queue is still empty after trying to fill it.'
                    f' Check that max_length ({self.max_length}) is not'
                    ' smaller than samples_per_volume'
                    f' ({self.samples_per_volume}), that the subjects'
                    ' dataset is not empty and that the sampler yields'
                    ' patches'
                )
                raise RuntimeError(message)
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self) -> str:
        """Return a summary of the queue configuration and current state."""
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def _print(self, *args) -> None:
        """Print the arguments only if the queue is in verbose mode."""
        if self.verbose:
            print(*args)  # noqa: T001

    def initialize_subjects_iterable(self) -> None:
        """Create a fresh iterator over the subjects loader and store it."""
        self._subjects_iterable = self.get_subjects_iterable()

    @property
    def subjects_iterable(self):
        """Iterator over the subjects loader, created on first access."""
        if self._subjects_iterable is None:
            self.initialize_subjects_iterable()
        return self._subjects_iterable

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the subjects dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches per epoch: subjects times samples per volume."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Extend the patches list with patches sampled from the subjects.

        ``samples_per_volume`` patches are taken from each loaded subject.
        The number of subjects loaded is limited both by the size of the
        subjects dataset and by how many subjects fit in ``max_length``.
        If ``shuffle_patches`` is ``True``, the patches list is shuffled
        at the end.

        A :class:`RuntimeWarning` is issued if ``max_length`` is not
        divisible by ``samples_per_volume``.
        """
        assert self.sampler is not None
        samples_per_volume = self.samples_per_volume
        if self.max_length % samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message, RuntimeWarning)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self._print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            # Show a progress bar while the subjects are being loaded
            subject_indices = trange(num_subjects_for_queue, leave=False)
        else:
            subject_indices = range(num_subjects_for_queue)
        for _ in subject_indices:
            subject = self.get_next_subject()
            # Kept in its own name so the loop iterable is not rebound
            patches_generator = self.sampler(subject)
            patches = list(islice(patches_generator, samples_per_volume))
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject(self) -> Subject:
        """Return the next subject, restarting the loader when exhausted."""
        # A StopIteration exception is expected when the queue is empty
        try:
            subject = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self.initialize_subjects_iterable()
            subject = next(self.subjects_iterable)
        return subject

    @staticmethod
    def get_first_item(batch):
        """Collate function that returns the first element of the batch."""
        return batch[0]

    def get_subjects_iterable(self) -> Iterator:
        """Build a loader over the subjects dataset and return its iterator.

        The loader uses a batch size of 1 and :meth:`get_first_item` as
        collate function, so each iteration yields a single subject.
        """
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self._print(
            f'\nCreating subjects loader with {self.num_workers} workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            batch_size=1,
            collate_fn=self.get_first_item,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
267