Passed
Push — master ( b06930...3ddbe5 )
by Fernando
03:55
created

torchio.data.queue.Queue.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 21
nop 9
dl 0
loc 22
rs 9.376
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
import random
2
import warnings
3
from itertools import islice
4
from typing import List, Iterator
5
6
from tqdm import trange
7
from torch.utils.data import Dataset, DataLoader
8
9
from .subject import Subject
10
from .sampler import PatchSampler
11
from .dataset import SubjectsDataset
12
13
14
class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.SubjectsDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
            It should be a multiple of :attr:`samples_per_volume` (a warning
            is issued otherwise) and at least as large as it, or no subject
            fits in the queue.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A sampler used to extract patches from the volumes.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed.

    This sketch can be used to experiment and understand how the queue works.
    In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> import torch
    >>> import torchio as tio
    >>> from torch.utils.data import DataLoader
    >>> patch_size = 96
    >>> queue_length = 300
    >>> samples_per_volume = 10
    >>> sampler = tio.data.UniformSampler(patch_size)
    >>> subject = tio.datasets.Colin27()
    >>> subjects_dataset = tio.SubjectsDataset(10 * [subject])
    >>> patches_queue = tio.Queue(
    ...     subjects_dataset,
    ...     queue_length,
    ...     samples_per_volume,
    ...     sampler,
    ...     num_workers=4,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=16)
    >>> num_epochs = 2
    >>> model = torch.nn.Identity()
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
    ...         targets = patches_batch['brain'][tio.DATA]  # key 'brain' is in subject
    ...         logits = model(inputs)  # model being an instance of torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: SubjectsDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        self.verbose = verbose
        # The subjects loader reads the attributes set above, so it must be
        # created after them
        self.subjects_iterable = self.get_subjects_iterable()
        self.patches_list: List[dict] = []
        self.num_sampled_patches = 0

    def __len__(self):
        """Return the number of patches that make up one epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _):
        """Pop one patch from the queue, refilling the queue if it is empty.

        The index is ignored: patches are always taken from the end of the
        internal list.

        Raises:
            IndexError: If the queue is still empty after trying to fill it.
        """
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self._print('Patches list is empty.')
            self.fill()
        if not self.patches_list:
            # Same exception type that popping from the empty list would
            # raise, but with a message that points at the likely cause
            message = (
                'No patches could be added to the queue.'
                f' Check that max_length ({self.max_length}) is not smaller'
                f' than samples_per_volume ({self.samples_per_volume}),'
                f' that the subjects dataset ({self.num_subjects} subjects)'
                ' is not empty and that the sampler yields patches'
            )
            raise IndexError(message)
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def _print(self, *args):
        """Print the arguments only if the queue is in verbose mode."""
        if self.verbose:
            print(*args)  # noqa: T001

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the subjects dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches extracted after going through all subjects."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Extract patches from the next subjects and add them to the queue.

        Up to ``max_length // samples_per_volume`` subjects (and no more than
        the number of subjects in the dataset) are loaded, and
        :attr:`samples_per_volume` patches are taken from each of them.
        Patches are shuffled afterwards if :attr:`shuffle_patches` is
        ``True``. A warning is issued if :attr:`max_length` is not divisible
        by :attr:`samples_per_volume`.
        """
        assert self.sampler is not None
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self._print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            subjects_range = trange(num_subjects_for_queue, leave=False)
        else:
            subjects_range = range(num_subjects_for_queue)
        for _ in subjects_range:
            subject = self.get_next_subject()
            # Use a different name from the loop iterable so that the loop
            # variable is not rebound while iterating
            patches_iterable = self.sampler(subject)
            patches = list(islice(patches_iterable, self.samples_per_volume))
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject(self) -> Subject:
        """Return the next subject, restarting the loader when exhausted."""
        # A StopIteration exception is expected when the queue is empty
        try:
            subject = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject = next(self.subjects_iterable)
        return subject

    def get_subjects_iterable(self) -> Iterator:
        """Return a new iterator that yields subjects one at a time."""
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self._print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        return iter(self._create_subjects_loader())

    def _create_subjects_loader(self) -> DataLoader:
        """Build the loader used to load and transform the subjects."""
        from operator import itemgetter

        return DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            # Each batch is a list with a single subject. itemgetter(0) is
            # used instead of a lambda because it can be pickled, which is
            # needed when worker processes are started with "spawn"
            collate_fn=itemgetter(0),
            shuffle=self.shuffle_subjects,
        )
193