Passed
Pull Request — master (#201)
by Fernando
01:01
created

torchio.data.queue.Queue.get_next_subject_sample()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 8
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import random
2
import warnings
3
from itertools import islice
4
from typing import List, Iterator
5
6
from tqdm import trange
7
from torch.utils.data import Dataset, DataLoader
8
9
from .sampler import PatchSampler
10
from .dataset import ImagesDataset
11
12
13
class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.ImagesDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
            It must be at least :attr:`samples_per_volume`, otherwise not
            even the patches of one volume fit in the queue and it cannot be
            filled.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A sampler used to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed and a
            ``tqdm`` progress bar is shown while the queue is being filled.

    This sketch can be used to experiment and understand how the queue works.
    In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> from torch.utils.data import DataLoader
    >>> import torchio
    >>> patches_queue = torchio.Queue(
    ...     subjects_dataset=subjects_dataset,  # instance of torchio.ImagesDataset
    ...     max_length=300,
    ...     samples_per_volume=10,
    ...     sampler=sampler,  # instance of a torchio PatchSampler subclass
    ...     num_workers=4,
    ...     shuffle_subjects=True,
    ...     shuffle_patches=True,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=4)
    >>> num_epochs = 20
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['image_name'][torchio.DATA]
    ...         targets = patches_batch['targets_name'][torchio.DATA]
    ...         logits = model(inputs)  # model is some torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: ImagesDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        self.verbose = verbose
        # Iterator over the subjects loader; each item is the list of
        # patches extracted from one subject (see get_subjects_iterable)
        self.subjects_iterable = self.get_subjects_iterable()
        # Patches currently stored in the queue, consumed with pop()
        self.patches_list: List[dict] = []
        # Total number of patches returned by __getitem__ so far
        self.num_sampled_patches = 0

    def __len__(self) -> int:
        """Return :attr:`iterations_per_epoch`, the patches drawn per epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _) -> dict:
        """Pop and return one patch, refilling the queue first if empty.

        The index is ignored: patches are served in queue order, which has
        already been randomized in :meth:`fill` if ``shuffle_patches`` is
        ``True``.
        """
        if not self.patches_list:
            self.print('Patches list is empty.')
            self.fill()
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self) -> str:
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def print(self, *args):
        """Print ``args`` only if the queue was created with ``verbose``."""
        if self.verbose:
            print(*args)

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the subjects dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches drawn per epoch (subjects x samples per volume)."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Load subjects and append their patches to the queue.

        At most ``max_length // samples_per_volume`` subjects are loaded, and
        never more than the number of subjects in the dataset. The patches are
        shuffled afterwards if ``shuffle_patches`` is ``True``.

        Raises:
            ValueError: If ``max_length`` is smaller than
                ``samples_per_volume``, in which case not a single subject
                would be loaded and the queue would stay empty.
        """
        assert self.sampler is not None
        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        if max_num_subjects_for_queue < 1:
            # Without this check no subject would be loaded, and popping from
            # the still-empty queue would raise an uninformative IndexError
            message = (
                f'Queue length ({self.max_length}) must be at least'
                ' the number of patches per volume'
                f' ({self.samples_per_volume}), otherwise the queue'
                ' cannot be filled'
            )
            raise ValueError(message)
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self.print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            iterable = trange(num_subjects_for_queue, leave=False)
        else:
            iterable = range(num_subjects_for_queue)
        for _ in iterable:
            patches = self.get_patches_from_dataset()
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_patches_from_dataset(self) -> List[dict]:
        """Return the list of patches extracted from the next subject.

        When the subjects loader is exhausted (every subject has been used in
        this pass), a new loader is created, reshuffling the subjects if
        ``shuffle_subjects`` is ``True``, and the next subject is taken from
        it.
        """
        # StopIteration is expected at the end of each pass over the subjects
        try:
            patches = next(self.subjects_iterable)
        except StopIteration as exception:
            self.print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            patches = next(self.subjects_iterable)
        return patches

    def get_subjects_iterable(self) -> Iterator:
        """Create an iterator over the subjects that yields lists of patches.

        A :class:`~torch.utils.data.DataLoader` is used so that subjects can
        be loaded by ``num_workers`` subprocesses. Its custom ``collate_fn``
        draws up to ``samples_per_volume`` patches from each subject with the
        sampler, so each item of the returned iterator is a list of patches.
        """
        # NOTE(review): collate_fn is a local closure; with num_workers > 0
        # and the 'spawn' multiprocessing start method it may fail to be
        # pickled for the worker processes — confirm on affected platforms.
        def collate_fn(subjects_list):
            # The loader uses the default batch size of 1, so the list holds
            # exactly one subject
            generator = self.sampler(subjects_list[0])
            return list(islice(generator, self.samples_per_volume))
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self.print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
188