Passed
Push — master ( 3c93a7...8f7525 )
by Fernando
01:09
created

torchio.data.queue.Queue.print()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import random
2
import warnings
3
from itertools import islice
4
from typing import List, Iterator
5
6
from tqdm import trange
7
from torch.utils.data import Dataset, DataLoader
8
9
from .sampler import PatchSampler
10
from .dataset import SubjectsDataset
11
12
13
class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.SubjectsDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A sampler used to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed.

    This sketch can be used to experiment and understand how the queue works.
    In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> from torch.utils.data import DataLoader
    >>> import torchio
    >>> patch_size = 96
    >>> queue_length = 300
    >>> samples_per_volume = 10
    >>> sampler = torchio.data.UniformSampler(patch_size)
    >>> patches_queue = torchio.Queue(
    ...     subjects_dataset,  # instance of torchio.SubjectsDataset
    ...     queue_length,
    ...     samples_per_volume,
    ...     sampler,
    ...     num_workers=4,
    ...     shuffle_subjects=True,
    ...     shuffle_patches=True,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=16)
    >>> num_epochs = 20
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['image_name'][torchio.DATA]
    ...         targets = patches_batch['targets_name'][torchio.DATA]
    ...         logits = model(inputs)  # model being an instance of torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: 'SubjectsDataset',
            max_length: int,
            samples_per_volume: int,
            sampler: 'PatchSampler',
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        self.verbose = verbose
        # The subjects iterable reads the attributes assigned above, so it
        # must be created after them
        self.subjects_iterable = self.get_subjects_iterable()
        self.patches_list: List[dict] = []
        self.num_sampled_patches = 0

    def __len__(self):
        """Return the number of patches extracted in one epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _):
        """Pop one patch from the queue, refilling the queue if it is empty.

        The index is ignored: patches are always taken from the end of the
        internal list.

        Raises:
            ValueError: If the queue is still empty after trying to fill it.
        """
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self._print('Patches list is empty.')
            self.fill()
            if not self.patches_list:
                # Without this check, the pop below would fail with an
                # uninformative "pop from empty list" error
                message = (
                    'The queue could not be filled. Check that max_length'
                    f' ({self.max_length}) is not smaller than'
                    f' samples_per_volume ({self.samples_per_volume}),'
                    ' that the subjects dataset is not empty'
                    ' and that the sampler yields patches'
                )
                raise ValueError(message)
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def _print(self, *args):
        """Print the arguments only if the queue is verbose."""
        if self.verbose:
            print(*args)  # noqa: T001

    @staticmethod
    def _get_first_item(batch):
        """Collate function that unwraps a batch containing one subject.

        A named function is used instead of a lambda so that it can be
        pickled and sent to worker processes.
        """
        return batch[0]

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the subjects dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches extracted from all subjects in one epoch."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Load subjects and append their patches to the queue.

        :attr:`samples_per_volume` patches are taken from each loaded subject.
        The patches list is shuffled afterwards if :attr:`shuffle_patches`
        is ``True``. A warning is issued if :attr:`max_length` is not
        divisible by :attr:`samples_per_volume`.
        """
        assert self.sampler is not None
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self._print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            iterable = trange(num_subjects_for_queue, leave=False)
        else:
            iterable = range(num_subjects_for_queue)
        for _ in iterable:
            subject_sample = self.get_next_subject_sample()
            # Use a dedicated name so the object being looped over is not
            # rebound inside its own loop
            patches_iterator = self.sampler(subject_sample)
            patches = list(islice(patches_iterator, self.samples_per_volume))
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject_sample(self) -> dict:
        """Return the next subject, restarting the loader when exhausted."""
        # A StopIteration exception is expected when the queue is empty
        try:
            subject_sample = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject_sample = next(self.subjects_iterable)
        return subject_sample

    def get_subjects_iterable(self) -> Iterator:
        """Create an iterator over the subjects dataset using a data loader."""
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self._print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=self._get_first_item,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
190