Passed
Pull Request — master (#201)
by Fernando
01:01
created

torchio.data.queue.Queue.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 21
nop 9
dl 0
loc 22
rs 9.376
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as soon as you need more or different data.

There are several approaches to avoid long parameter lists:

1
import random
2
import warnings
3
from itertools import islice
4
from typing import List, Iterator
5
6
from tqdm import trange
7
from torch.utils.data import Dataset, DataLoader
8
9
from .sampler import PatchSampler
10
from .dataset import ImagesDataset
11
12
13
class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    The queue loads subjects from :attr:`subjects_dataset` through an
    internal :class:`torch.utils.data.DataLoader`, extracts
    :attr:`samples_per_volume` patches from each loaded subject using
    :attr:`sampler`, and stores them in a list. Patches are popped from
    that list every time the queue is indexed, and the list is refilled
    when it becomes empty.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.ImagesDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more CPU memory is needed to store the patches.
            It should be at least :attr:`samples_per_volume`; otherwise
            no subject fits in the queue and no patch can be retrieved.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        sampler: A sampler used to extract patches from the volumes.
            It is called with one subject and must return an iterable
            of patches.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed.

    This sketch can be used to experiment and understand how the queue works.
    In this case, :attr:`shuffle_subjects` is ``False``
    and :attr:`shuffle_patches` is ``True``.

    .. raw:: html

        <embed>
            <iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
        </embed>

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> from torch.utils.data import DataLoader
    >>> import torchio
    >>> patches_queue = torchio.Queue(
    ...     subjects_dataset=subjects_dataset,  # instance of torchio.ImagesDataset
    ...     max_length=300,
    ...     samples_per_volume=10,
    ...     sampler=sampler,  # instance of a PatchSampler
    ...     num_workers=4,
    ...     shuffle_subjects=True,
    ...     shuffle_patches=True,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=4)
    >>> num_epochs = 20
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['image_name'][torchio.DATA]
    ...         targets = patches_batch['targets_name'][torchio.DATA]
    ...         logits = model(inputs)  # model is some torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: ImagesDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler = sampler
        self.num_workers = num_workers
        self.verbose = verbose
        # The iterable reads the attributes above, so it is created last
        self.subjects_iterable = self.get_subjects_iterable()
        self.patches_list: List[dict] = []
        self.num_sampled_patches = 0

    def __len__(self) -> int:
        """Return the number of patches that make up one epoch."""
        return self.iterations_per_epoch

    def __getitem__(self, _) -> dict:
        """Pop one patch from the queue, refilling the queue if it is empty.

        The index is ignored: patches are always taken from the end of
        the internal list.

        Raises:
            IndexError: If the queue is still empty after trying to fill it.
        """
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self.print('Patches list is empty.')
            self.fill()
        if not self.patches_list:
            # fill() added nothing, e.g. because max_length is smaller than
            # samples_per_volume. Same exception type as popping from an
            # empty list, but with a message that explains the cause.
            message = (
                'No patches could be extracted to fill the queue'
                f' (max_length={self.max_length},'
                f' samples_per_volume={self.samples_per_volume},'
                f' num_subjects={self.num_subjects})'
            )
            raise IndexError(message)
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self) -> str:
        """Return a summary of the queue settings and its current state."""
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def print(self, *args) -> None:
        """Print the arguments only if the queue is verbose."""
        if self.verbose:
            print(*args)

    @property
    def num_subjects(self) -> int:
        """Number of subjects in the subjects dataset."""
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        """Number of patches currently stored in the queue."""
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        """Number of patches extracted when every subject is sampled once."""
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        """Load subjects and append their patches to the queue.

        A warning is issued if :attr:`max_length` is not divisible by
        :attr:`samples_per_volume`. The patches list is shuffled at the
        end if :attr:`shuffle_patches` is ``True``.
        """
        assert self.sampler is not None
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self.print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            # Show a progress bar while the subjects are being loaded
            iterable = trange(num_subjects_for_queue, leave=False)
        else:
            iterable = range(num_subjects_for_queue)
        for _ in iterable:
            patches = self.get_patches_from_dataset()
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_patches_from_dataset(self) -> List[dict]:
        """Return the patches extracted from the next subject.

        The subjects iterable is created again when it is exhausted, so
        the subjects dataset is traversed repeatedly.
        """
        # A StopIteration exception is expected when the queue is empty
        try:
            subject_sample = next(self.subjects_iterable)
        except StopIteration as exception:
            self.print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject_sample = next(self.subjects_iterable)
        return subject_sample

    def get_subjects_iterable(self) -> Iterator:
        """Return an iterator yielding the list of patches of one subject.

        Each item is produced by a :class:`torch.utils.data.DataLoader`
        whose collate function runs the sampler on the loaded subject
        and keeps the first :attr:`samples_per_volume` patches.
        """
        def collate_fn(subjects_list):
            # The loader uses the default batch size, so only the first
            # (and only expected) subject of the batch is sampled
            generator = self.sampler(subjects_list[0])
            return list(islice(generator, self.samples_per_volume))
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self.print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
188