Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

torchio.data.queue.Queue.fill()   B

Complexity

Conditions 5

Size

Total Lines 29
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 21
nop 1
dl 0
loc 29
rs 8.9093
c 0
b 0
f 0
1
import random
2
import warnings
3
from typing import List, Iterator
4
from itertools import islice
5
from tqdm import trange
6
from torch.utils.data import Dataset, DataLoader
7
from .. import TypeTuple
8
from .dataset import ImagesDataset
9
from .sampler import ImageSampler
10
11
12
class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.ImagesDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more RAM is needed to store the patches.
            Must be greater than or equal to :attr:`samples_per_volume`.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower. Must be at least ``1``.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        sampler_class: A subclass of
            :class:`~torchio.data.sampler.ImageSampler` (the class itself, not
            an instance) used to define the patches sampling strategy. It is
            called as ``sampler_class(subject_sample, patch_size)`` and the
            resulting object is iterated to draw patches.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed and a
            progress bar is shown while the queue is being filled.

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> from torch.utils.data import DataLoader
    >>> import torchio
    >>> patches_queue = torchio.Queue(
    ...     subjects_dataset=subjects_dataset,  # instance of torchio.ImagesDataset
    ...     max_length=300,
    ...     samples_per_volume=10,
    ...     patch_size=96,
    ...     sampler_class=torchio.sampler.ImageSampler,
    ...     num_workers=4,
    ...     shuffle_subjects=True,
    ...     shuffle_patches=True,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=4)
    >>> num_epochs = 20
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['image_name'][torchio.DATA]
    ...         targets = patches_batch['targets_name'][torchio.DATA]
    ...         logits = model(inputs)  # model is some torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: ImagesDataset,
            max_length: int,
            samples_per_volume: int,
            patch_size: TypeTuple,
            sampler_class: ImageSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler_class = sampler_class
        self.patch_size = patch_size
        self.num_workers = num_workers
        # Must be set before get_subjects_iterable(), which calls self.print
        self.verbose = verbose
        # Iterator over single subject samples; it is recreated whenever it
        # is exhausted (see get_next_subject_sample)
        self.subjects_iterable = self.get_subjects_iterable()
        # Patches waiting to be returned by __getitem__
        self.patches_list: List[dict] = []
        # Total number of patches returned by __getitem__ so far
        self.num_sampled_patches = 0

    def __len__(self):
        """Return the number of patches in one epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _):
        """Pop one patch from the queue, refilling it first if it is empty.

        The index argument is ignored: patches are always taken from the end
        of the internal list.
        """
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self.print('Patches list is empty.')
            self.fill()
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def print(self, *args):
        """Print ``args`` only when the queue was created with ``verbose``."""
        if self.verbose:
            print(*args)

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the underlying subjects dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches drawn in one pass over all subjects."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Load subjects and append ``samples_per_volume`` patches from each.

        At most ``max_length // samples_per_volume`` subjects are loaded, and
        never more than the number of subjects in the dataset. The patches are
        shuffled afterwards if ``shuffle_patches`` is ``True``.

        Raises:
            ValueError: If ``sampler_class`` or ``patch_size`` is ``None``,
                if ``samples_per_volume`` is smaller than ``1``, or if
                ``max_length`` is smaller than ``samples_per_volume``. In the
                last case no subject would ever be loaded and the queue would
                stay empty.
        """
        if self.sampler_class is None:
            raise ValueError('sampler_class must not be None')
        if self.patch_size is None:
            raise ValueError('patch_size must not be None')
        if self.samples_per_volume < 1:
            message = (
                'The number of patches per volume must be at least 1,'
                f' but it is {self.samples_per_volume}'
            )
            raise ValueError(message)
        if self.max_length < self.samples_per_volume:
            # Otherwise max_length // samples_per_volume == 0, nothing is
            # loaded, and __getitem__ fails popping from an empty list
            message = (
                f'Queue length ({self.max_length}) must be greater than or'
                ' equal to the number of patches per volume'
                f' ({self.samples_per_volume})'
            )
            raise ValueError(message)
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self.print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            iterable = trange(num_subjects_for_queue, leave=False)
        else:
            iterable = range(num_subjects_for_queue)
        for _ in iterable:
            subject_sample = self.get_next_subject_sample()
            sampler = self.sampler_class(subject_sample, self.patch_size)
            samples = list(islice(sampler, self.samples_per_volume))
            self.patches_list.extend(samples)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject_sample(self) -> dict:
        """Return the next subject sample, restarting the loader if needed.

        When the current subjects iterator is exhausted, a new one is created
        via get_subjects_iterable(). With ``shuffle_subjects`` set, the new
        iterator uses a new shuffled order.
        """
        # A StopIteration exception is expected when all subjects have been
        # consumed from the current subjects iterator
        try:
            subject_sample = next(self.subjects_iterable)
        except StopIteration as exception:
            self.print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject_sample = next(self.subjects_iterable)
        return subject_sample

    @staticmethod
    def get_first_item(batch):
        """Return the only element of a batch of size 1.

        This is the ``collate_fn`` of the subjects loader. It is a named
        function rather than a lambda so that it can be pickled when the
        loader sends it to worker processes. Pickling is required with the
        ``spawn`` start method, which is the default on Windows and macOS.
        """
        return batch[0]

    def get_subjects_iterable(self) -> Iterator:
        """Create an iterator that yields one (uncollated) subject at a time."""
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self.print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            batch_size=1,
            num_workers=self.num_workers,
            collate_fn=self.get_first_item,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
184