Passed
Push — master ( 3c93a7...8f7525 )
by Fernando
01:09
created

torchio.data.queue.Queue.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 21
nop 9
dl 0
loc 22
rs 9.376
c 0
b 0
f 0

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
import random
2
import warnings
3
from itertools import islice
4
from typing import List, Iterator
5
6
from tqdm import trange
7
from torch.utils.data import Dataset, DataLoader
8
9
from .sampler import PatchSampler
10
from .dataset import SubjectsDataset
11
12
13
class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.SubjectsDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A sampler used to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed.

    This sketch can be used to experiment and understand how the queue works.
    In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> from torch.utils.data import DataLoader
    >>> import torchio
    >>> patch_size = 96
    >>> queue_length = 300
    >>> samples_per_volume = 10
    >>> sampler = torchio.data.UniformSampler(patch_size)
    >>> patches_queue = torchio.Queue(
    ...     subjects_dataset,  # instance of torchio.SubjectsDataset
    ...     queue_length,
    ...     samples_per_volume,
    ...     sampler,
    ...     num_workers=4,
    ...     shuffle_subjects=True,
    ...     shuffle_patches=True,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=16)
    >>> num_epochs = 20
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['image_name'][torchio.DATA]
    ...         targets = patches_batch['targets_name'][torchio.DATA]
    ...         logits = model(inputs)  # model being an instance of torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: SubjectsDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        """Store the configuration and create the first subjects iterator.

        The queue starts empty; patches are only extracted the first time an
        item is requested.
        """
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        self.verbose = verbose
        # Must come after the attributes above: the iterable reads
        # subjects_dataset, num_workers, shuffle_subjects and verbose
        self.subjects_iterable = self.get_subjects_iterable()
        self.patches_list: List[dict] = []
        self.num_sampled_patches = 0

    def __len__(self):
        """Return the number of patches that make up one epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _):
        """Pop one patch from the queue, refilling the queue if it is empty.

        The index is ignored: patches are always taken from the end of the
        internal patches list.

        Raises:
            RuntimeError: If the queue is still empty after trying to fill
                it, e.g. because ``max_length`` is smaller than
                ``samples_per_volume`` or the subjects dataset is empty.
        """
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self._print('Patches list is empty.')
            self.fill()
            if not self.patches_list:
                # Fail with an explicit message instead of the opaque
                # "pop from empty list" that list.pop() would raise
                message = (
                    'The queue is still empty after filling it. Check that'
                    f' max_length ({self.max_length}) is not smaller than'
                    f' samples_per_volume ({self.samples_per_volume})'
                    ' and that the subjects dataset is not empty'
                )
                raise RuntimeError(message)
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        """Return a summary of the configuration and current state."""
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def _print(self, *args):
        """Print ``args`` only if the queue was created as verbose."""
        if self.verbose:
            print(*args)  # noqa: T001

    @staticmethod
    def _get_first_item(batch):
        """Return the only element of a single-subject batch.

        Used as ``collate_fn`` of the subjects loader. A named function is
        used instead of a lambda because lambdas cannot be pickled, which is
        required when worker processes are spawned rather than forked.
        """
        return batch[0]

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the subjects dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches per epoch (subjects times patches per volume)."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Extract patches from the next subjects and add them to the queue.

        At most ``max_length // samples_per_volume`` subjects are used, and
        never more than the number of subjects in the dataset. Up to
        ``samples_per_volume`` patches are taken from each subject. The
        patches list is shuffled afterwards if ``shuffle_patches`` is set.

        A warning is issued if ``max_length`` is not a multiple of
        ``samples_per_volume``.
        """
        assert self.sampler is not None
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self._print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            # Show a progress bar while the subjects are being loaded
            subjects_range = trange(num_subjects_for_queue, leave=False)
        else:
            subjects_range = range(num_subjects_for_queue)
        for _ in subjects_range:
            subject_sample = self.get_next_subject_sample()
            patches_iterable = self.sampler(subject_sample)
            # The sampler may yield indefinitely, so only take what is needed
            patches = islice(patches_iterable, self.samples_per_volume)
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject_sample(self) -> dict:
        """Return the next subject, restarting the loader when exhausted.

        If the restarted loader is exhausted as well (no subjects at all),
        the resulting ``StopIteration`` is not handled here.
        """
        # A StopIteration exception is expected when the queue is empty
        try:
            subject_sample = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject_sample = next(self.subjects_iterable)
        return subject_sample

    def get_subjects_iterable(self) -> Iterator:
        """Create an iterator that yields one subject at a time.

        The subjects are loaded through a
        :class:`torch.utils.data.DataLoader` using ``num_workers`` workers
        and shuffled if ``shuffle_subjects`` is set.
        """
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self._print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=self._get_first_item,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
190