Passed
Push — master (2b3601...ce2161) by Fernando, created 01:27

torchio.data.queue.Queue._fill()   B

Complexity
    Conditions: 5

Size
    Total Lines: 28
    Code Lines: 20

Duplication
    Lines: 0
    Ratio: 0%

Importance
    Changes: 0
Metric  Value
------  ------
eloc    20
dl      0
loc     28
rs      8.9332
c       0
b       0
f       0
cc      5
nop     1
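Reading the raw metrics against the summary above: cc (cyclomatic complexity) 5 matches the Conditions count, eloc 20 the Code Lines, loc 28 the Total Lines, dl 0 the duplicated lines, and nop 1 presumably the method's single parameter (self).
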
import random
import warnings
from itertools import islice
from typing import List, Iterator, Optional

import humanize
from tqdm import trange
from torch.utils.data import Dataset, DataLoader

from .subject import Subject
from .sampler import PatchSampler
from .dataset import SubjectsDataset


class Queue(Dataset):
    r"""Queue used for stochastic patch-based training.

    A training iteration (i.e., forward and backward pass) performed on a
    GPU is usually faster than loading, preprocessing, augmenting, and cropping
    a volume on a CPU.
    Most preprocessing operations could be performed using a GPU,
    but these devices are typically reserved for training the CNN so that batch
    size and input tensor size can be as large as possible.
    Therefore, it is beneficial to prepare (i.e., load, preprocess and augment)
    the volumes using multiprocessing CPU techniques in parallel with the
    forward-backward passes of a training iteration.
    Once a volume is appropriately prepared, it is computationally beneficial to
    sample multiple patches from a volume rather than having to prepare the same
    volume each time a patch needs to be extracted.
    The sampled patches are then stored in a buffer or *queue* until
    the next training iteration, at which point they are loaded onto the GPU
    for inference.
    For this, TorchIO provides the :class:`~torchio.data.Queue` class, which also
    inherits from the PyTorch :class:`~torch.utils.data.Dataset`.
    In this queueing system,
    samplers behave as generators that yield patches from random locations
    in volumes contained in the :class:`~torchio.data.SubjectsDataset`.

    The end of a training epoch is defined as the moment after which patches
    from all subjects have been used for training.
    At the beginning of each training epoch,
    the subjects list in the :class:`~torchio.data.SubjectsDataset` is shuffled,
    as is typically done in machine learning pipelines to increase the variance
    of training instances during model optimization.
    A PyTorch loader queries the datasets copied in each process,
    which load and process the volumes in parallel on the CPU.
    A patches list is filled with patches extracted by the sampler,
    and the queue is shuffled once it has reached a specified maximum length so
    that batches are composed of patches from different subjects.
    The internal data loader continues querying the
    :class:`~torchio.data.SubjectsDataset` using multiprocessing.
    The patches list, when emptied, is refilled with new patches.
    A second data loader, external to the queue,
    may be used to collate batches of patches stored in the queue,
    which are passed to the neural network.

    Args:
        subjects_dataset: Instance of :class:`~torchio.data.SubjectsDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A subclass of :class:`~torchio.data.sampler.PatchSampler` used
            to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e., when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        start_background: If ``True``, the loader will start working in the
            background as soon as the queue is instantiated.
        verbose: If ``True``, some debugging messages will be printed.

    This diagram represents the connection between
    a :class:`~torchio.data.SubjectsDataset`,
    a :class:`~torchio.data.Queue`
    and the :class:`~torch.utils.data.DataLoader` used to pop batches from the
    queue.

    .. image:: https://raw.githubusercontent.com/fepegar/torchio/master/docs/images/diagram_patches.svg
        :alt: Training with patches

    This sketch can be used to experiment with the queue and understand how it
    works. In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue, so you should always use ``num_workers=0`` for
        the :class:`~torch.utils.data.DataLoader` you instantiate to generate
        training batches.

    Example:

    >>> import torch
    >>> import torchio as tio
    >>> from torch.utils.data import DataLoader
    >>> patch_size = 96
    >>> queue_length = 300
    >>> samples_per_volume = 10
    >>> sampler = tio.data.UniformSampler(patch_size)
    >>> subject = tio.datasets.Colin27()
    >>> subjects_dataset = tio.SubjectsDataset(10 * [subject])
    >>> patches_queue = tio.Queue(
    ...     subjects_dataset,
    ...     queue_length,
    ...     samples_per_volume,
    ...     sampler,
    ...     num_workers=4,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=16)
    >>> num_epochs = 2
    >>> model = torch.nn.Identity()
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
    ...         targets = patches_batch['brain'][tio.DATA]  # key 'brain' is in subject
    ...         logits = model(inputs)  # model being an instance of torch.nn.Module

    """  # noqa: E501
    def __init__(
            self,
            subjects_dataset: SubjectsDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            start_background: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        self.verbose = verbose
        self._subjects_iterable = None
        if start_background:
            self._initialize_subjects_iterable()
        self.patches_list: List[Subject] = []
        self.num_sampled_patches = 0

    def __len__(self):
        return self.iterations_per_epoch

    def __getitem__(self, _):
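        # The index from the enclosing DataLoader is ignored: patches are
        # simply popped from the end of the list, which is refilled on demand
        # by _fill()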
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self._print('Patches list is empty.')
            self._fill()
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def _print(self, *args):
        if self.verbose:
            print(*args)  # noqa: T001

    def _initialize_subjects_iterable(self):
        self._subjects_iterable = self._get_subjects_iterable()

    @property
    def subjects_iterable(self):
        if self._subjects_iterable is None:
            self._initialize_subjects_iterable()
        return self._subjects_iterable

    @property
    def num_subjects(self) -> int:
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
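        # E.g., with the docstring example above (10 subjects and
        # samples_per_volume=10), len(queue) is 100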
        return self.num_subjects * self.samples_per_volume

    def _fill(self) -> None:
        assert self.sampler is not None
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message, RuntimeWarning)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
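        # With the docstring example (max_length=300, samples_per_volume=10),
        # each fill draws patches from at most 300 // 10 == 30 subjects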
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self._print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            iterable = trange(num_subjects_for_queue, leave=False)
        else:
            iterable = range(num_subjects_for_queue)
        for _ in iterable:
            subject = self._get_next_subject()
            patches_iterable = self.sampler(subject)
            patches = list(islice(patches_iterable, self.samples_per_volume))
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def _get_next_subject(self) -> Subject:
        # A StopIteration exception is expected when the subjects iterator is
        # exhausted, i.e., at the end of an epoch; re-creating the iterable
        # starts a new epoch
        try:
            subject = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self._initialize_subjects_iterable()
            subject = next(self.subjects_iterable)
        return subject

    @staticmethod
    def _get_first_item(batch):
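        # Used as collate_fn below so that the subjects loader yields plain
        # Subject instances instead of collated tensor batches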
        return batch[0]

    def _get_subjects_iterable(self) -> Iterator:
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self._print(
            f'\nCreating subjects loader with {self.num_workers} workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            batch_size=1,
            collate_fn=self._get_first_item,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)

    def get_max_memory(self, subject: Optional[Subject] = None) -> int:
        """Get the maximum RAM occupied by the patches queue in bytes.

        Args:
            subject: Sample subject to compute the size of a patch.
        """
        images_channels = 0
        if subject is None:
            subject = self.subjects_dataset[0]
        for image in subject.get_images(intensity_only=False):
            images_channels += len(image.data)
        voxels_in_patch = int(self.sampler.patch_size.prod() * images_channels)
        bytes_per_patch = 4 * voxels_in_patch  # assume float32
        return int(bytes_per_patch * self.max_length)

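    # Worked example for get_max_memory (a sketch, assuming two single-channel
    # images per subject, e.g. an intensity image and a label map, and an
    # isotropic patch size of 96):
    #   voxels_in_patch = 96**3 * 2 = 1,769,472
    #   bytes_per_patch = 4 * 1,769,472 = 7,077,888 (6.75 MiB)
    #   with max_length=300, get_max_memory() returns about 2.0 GiB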
    def get_max_memory_pretty(self, subject: Optional[Subject] = None) -> str:
        """Get human-readable maximum RAM occupied by the patches queue.

        Args:
            subject: Sample subject to compute the size of a patch.
        """
        memory = self.get_max_memory(subject=subject)
        return humanize.naturalsize(memory, binary=True)
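
# Minimal usage sketch of the memory helpers, following the docstring example
# above (the output value is hypothetical; it depends on the subject, the
# sampler's patch size, and max_length):
#   >>> patches_queue.get_max_memory_pretty()
#   '2.0 GiB'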