Passed
Pull Request — master (#609)
by Fernando · created · 01:18

Queue._parse_samples_per_volume()    A

Complexity
    Conditions: 5

Size
    Total lines: 19
    Code lines: 14

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
eloc      14
dl        0
loc       19
rs        9.2333
c         0
b         0
f         0
cc        5
nop       2
import random
import warnings
from itertools import islice
from typing import List, Iterator, Optional, Sequence, Union

import humanize
from torch.utils.data import Dataset, DataLoader, RandomSampler

from .subject import Subject
from .sampler import PatchSampler
from .dataset import SubjectsDataset


class Queue(Dataset):
    r"""Queue used for stochastic patch-based training.

    A training iteration (i.e., forward and backward pass) performed on a
    GPU is usually faster than loading, preprocessing, augmenting, and cropping
    a volume on a CPU.
    Most preprocessing operations could be performed using a GPU,
    but these devices are typically reserved for training the CNN so that the
    batch size and input tensor size can be as large as possible.
    Therefore, it is beneficial to prepare (i.e., load, preprocess and augment)
    the volumes using multiprocessing CPU techniques in parallel with the
    forward-backward passes of a training iteration.
    Once a volume is appropriately prepared, it is computationally beneficial to
    sample multiple patches from it rather than having to prepare the same
    volume each time a patch needs to be extracted.
    The sampled patches are then stored in a buffer or *queue* until
    the next training iteration, at which point they are loaded onto the GPU
    for the forward and backward passes.
    For this, TorchIO provides the :class:`~torchio.data.Queue` class, which also
    inherits from the PyTorch :class:`~torch.utils.data.Dataset`.
    In this queueing system,
    samplers behave as generators that yield patches from random locations
    in volumes contained in the :class:`~torchio.data.SubjectsDataset`.

    The end of a training epoch is defined as the moment after which patches
    from all subjects have been used for training.
    At the beginning of each training epoch,
    the subjects list in the :class:`~torchio.data.SubjectsDataset` is shuffled,
    as is typically done in machine learning pipelines to increase the variance
    of training instances during model optimization.
    A PyTorch loader queries the datasets copied in each process,
    which load and process the volumes in parallel on the CPU.
    A patches list is filled with patches extracted by the sampler,
    and the queue is shuffled once it has reached a specified maximum length so
    that batches are composed of patches from different subjects.
    The internal data loader continues querying the
    :class:`~torchio.data.SubjectsDataset` using multiprocessing.
    The patches list, when emptied, is refilled with new patches.
    A second data loader, external to the queue,
    may be used to collate batches of patches stored in the queue,
    which are passed to the neural network.

    Args:
        subjects_dataset: Instance of :class:`~torchio.data.SubjectsDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
        samples_per_volume: Sequence with the number of patches to extract
            from each volume. If a single integer is given, the same number of
            patches is extracted from every volume. A small number of patches
            ensures a large variability in the queue, but training will be
            slower.
        sampler: A subclass of :class:`~torchio.data.sampler.PatchSampler` used
            to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        start_background: If ``True``, the loader will start working in the
            background as soon as the queue is instantiated.
        verbose: If ``True``, some debugging messages will be printed.

    This diagram represents the connection between
    a :class:`~torchio.data.SubjectsDataset`,
    a :class:`~torchio.data.Queue`
    and the :class:`~torch.utils.data.DataLoader` used to pop batches from the
    queue.

    .. image:: https://raw.githubusercontent.com/fepegar/torchio/master/docs/images/diagram_patches.svg
        :alt: Training with patches

    This sketch can be used to experiment with and understand how the queue
    works. In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue, so you should always use ``num_workers=0`` for
        the :class:`~torch.utils.data.DataLoader` you instantiate to generate
        training batches.

    Example:

    >>> import torch
    >>> import torchio as tio
    >>> from torch.utils.data import DataLoader
    >>> patch_size = 96
    >>> queue_length = 300
    >>> samples_per_volume = 10
    >>> sampler = tio.data.UniformSampler(patch_size)
    >>> subject = tio.datasets.Colin27()
    >>> subjects_dataset = tio.SubjectsDataset(10 * [subject])
    >>> patches_queue = tio.Queue(
    ...     subjects_dataset,
    ...     queue_length,
    ...     samples_per_volume,
    ...     sampler,
    ...     num_workers=4,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=16)
    >>> num_epochs = 2
    >>> model = torch.nn.Identity()
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
    ...         targets = patches_batch['brain'][tio.DATA]  # key 'brain' is in subject
    ...         logits = model(inputs)  # model being an instance of torch.nn.Module
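
    The :attr:`samples_per_volume` argument may also be a sequence with one
    value per subject. A minimal sketch reusing the names defined above (the
    per-subject counts below are arbitrary):

    >>> len(patches_queue) == samples_per_volume * len(subjects_dataset)
    True
    >>> samples_sequence = 10 * [3]  # one count per subject in the dataset
    >>> variable_queue = tio.Queue(
    ...     subjects_dataset,
    ...     queue_length,
    ...     samples_sequence,
    ...     sampler,
    ...     num_workers=4,
    ... )
    >>> len(variable_queue)  # sum of the per-subject counts
    30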

    """  # noqa: E501
    def __init__(
            self,
            subjects_dataset: SubjectsDataset,
            max_length: int,
            samples_per_volume: Union[int, Sequence[int]],
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            start_background: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = self._parse_samples_per_volume(
            samples_per_volume)
        self.sampler = sampler
        self.num_workers = num_workers
        self.verbose = verbose
        self._subjects_iterable = None
        self.patches_list: List[Subject] = []
        self.num_sampled_patches = 0

        if start_background:
            self._initialize_subjects_iterable()

        # Number of patches that remain to be extracted from each volume
        self.counter_samples_per_volume = self.samples_per_volume.copy()
        # Index of the subject from which patches are currently extracted
        self.idx_subject = -1
        # Current subject. Cached as an attribute to avoid preparing the same
        # volume again for every batch of patches (more details in _fill())
        self.curr_subject = None

    def __len__(self):
        return self.iterations_per_epoch

    def __getitem__(self, _):
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self._print('Patches list is empty.')
            self._fill()
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def _parse_samples_per_volume(self, samples_per_volume):
        if isinstance(samples_per_volume, int):
            samples_per_volume = self.num_subjects * [samples_per_volume]
        message = (
            'The value of samples_per_volume must be an integer'
            ' or a sequence of integers'
        )
        if isinstance(samples_per_volume, Sequence):
            if not all(isinstance(n, int) for n in samples_per_volume):
                raise TypeError(message)
        else:
            raise TypeError(message)
        if len(samples_per_volume) != self.num_subjects:
            message = (
                'The length of samples_per_volume must be equal to the number'
                ' of subjects in the subjects dataset'
            )
            raise ValueError(message)
        return samples_per_volume
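
    # For illustration (hypothetical values): with a three-subject dataset,
    # _parse_samples_per_volume(10) returns [10, 10, 10], a sequence such
    # as [1, 2, 3] is returned unchanged, [1, 2] raises ValueError (wrong
    # length), and [1, 'a', 3] raises TypeError (non-integer values).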

    def _print(self, *args):
        if self.verbose:
            print(*args)  # noqa: T001

    def _initialize_subjects_iterable(self):
        self._subjects_iterable = self._get_subjects_iterable()

    @property
    def subjects_iterable(self):
        if self._subjects_iterable is None:
            self._initialize_subjects_iterable()
        return self._subjects_iterable

    @property
    def num_subjects(self) -> int:
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        return sum(self.samples_per_volume)


    def _fill(self) -> None:
        assert self.sampler is not None
        if self.max_length % self.iterations_per_epoch != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the total number of'
                f' patches per epoch ({self.iterations_per_epoch})'
            )
            warnings.warn(message, RuntimeWarning)

        # If the counter of samples per volume is empty (i.e., the end of
        # the epoch has been reached), refill it.
        if sum(self.counter_samples_per_volume) == 0:
            self._initialize_subjects_iterable()
            self.counter_samples_per_volume = self.samples_per_volume.copy()
            self.idx_subject = -1
            self.curr_subject = None

        # Add patches
        # 3 stopping conditions (OR):
        #   1) The number of current patches in patches_list >= max patches
        #   2) There are no more patches that need to be added
        #      (i.e., remaining patches -> 0)
        #   3) There are no more subjects to extract patches from
        while (len(self.patches_list) < self.max_length
                and sum(self.counter_samples_per_volume) != 0
                and self.idx_subject < self.num_subjects):

            if (self.curr_subject is None
                    or self.counter_samples_per_volume[self.idx_subject] == 0):

                self.curr_subject = self._get_next_subject()
                self.idx_subject += 1

            # Whether to fill the queue with a portion of the patches of a
            # specific subject, or with all of its remaining patches
            if (len(self.patches_list)
                    + self.counter_samples_per_volume[self.idx_subject]
                    > self.max_length):
                # Take a portion
                spv = self.max_length - len(self.patches_list)
            else:
                spv = self.counter_samples_per_volume[self.idx_subject]

            self.counter_samples_per_volume[self.idx_subject] -= spv
            iterable = self.sampler(self.curr_subject)
            patches = list(islice(iterable, spv))
            self.patches_list.extend(patches)

        if self.shuffle_patches:
            random.shuffle(self.patches_list)
        else:
            # Reverse the order of the patches so that list.pop() starts
            # from the beginning
            self.patches_list = self.patches_list[::-1]
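
    # Worked example with the numbers from the class docstring (a sketch):
    # with 10 subjects and samples_per_volume == 10, iterations_per_epoch
    # is 100, so a call to _fill() with max_length == 300 stops through
    # condition 2 after adding all 100 patches of the epoch, whereas with
    # max_length == 30 it stops through condition 1 after sampling the
    # first three subjects.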

    def _get_next_subject(self) -> Subject:
        # A StopIteration exception is expected when the subjects iterator
        # is exhausted (i.e., at the end of an epoch)
        try:
            subject = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self._initialize_subjects_iterable()
            subject = next(self.subjects_iterable)
        return subject

    @staticmethod
    def _get_first_item(batch):
        return batch[0]

    def _get_subjects_iterable(self) -> Iterator:
        # A DataLoader is needed to handle parallelism,
        # but this loader is always expected to yield single subject samples

        # The same random permutation is applied to the subjects and to
        # samples_per_volume so that each subject keeps its own number of
        # patches
        if self.shuffle_subjects:
            random_idx = list(RandomSampler(self.subjects_dataset))
            local_sub_dataset = [self.subjects_dataset[i] for i in random_idx]
            self.samples_per_volume = [self.samples_per_volume[i]
                                       for i in random_idx]
        else:
            local_sub_dataset = self.subjects_dataset

        self._print(
            f'\nCreating subjects loader with {self.num_workers} workers')
        subjects_loader = DataLoader(
            local_sub_dataset,
            num_workers=self.num_workers,
            batch_size=1,
            collate_fn=self._get_first_item,
            shuffle=False,  # shuffling is handled above when needed
        )
        return iter(subjects_loader)
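
    # For illustration (hypothetical values): if samples_per_volume is
    # [1, 2, 3] and the sampled permutation is [2, 0, 1], both the subjects
    # and the counts are reordered consistently, so samples_per_volume
    # becomes [3, 1, 2] and each subject keeps its own number of patches.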

    def get_max_memory(self, subject: Optional[Subject] = None) -> int:
        """Get the maximum RAM occupied by the patches queue in bytes.

        Args:
            subject: Sample subject to compute the size of a patch.
        """
        images_channels = 0
        if subject is None:
            subject = self.subjects_dataset[0]
        for image in subject.get_images(intensity_only=False):
            images_channels += len(image.data)
        voxels_in_patch = int(self.sampler.patch_size.prod() * images_channels)
        bytes_per_patch = 4 * voxels_in_patch  # assume float32
        return int(bytes_per_patch * self.max_length)
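
    # Worked example (a sketch, assuming a single-channel image and the
    # values from the class docstring): a 96³ float32 patch occupies
    # 96³ × 4 = 3,538,944 bytes, so a queue with max_length == 300 may
    # need up to 300 × 3,538,944 bytes ≈ 1012.5 MiB of RAM.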

    def get_max_memory_pretty(self, subject: Optional[Subject] = None) -> str:
        """Get human-readable maximum RAM occupied by the patches queue.

        Args:
            subject: Sample subject to compute the size of a patch.
        """
        memory = self.get_max_memory(subject=subject)
        return humanize.naturalsize(memory, binary=True)