Passed: Push to master (7b848f...cac223) by Fernando, created 02:39

torchio.data.queue.Queue.__init__()   Rating: A

Complexity:   Conditions 1
Size:         Total Lines 24, Code Lines 23
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0
Metric   Value
cc       1
eloc     23
nop      10
dl       0
loc      24
rs       9.328
c        0
b        0
f        0

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as soon as they need more, or different, data.

There are several common approaches to avoiding long parameter lists:

- Group related parameters into a parameter object and pass that instead (see the sketch below this list).
- Collect option-style flags such as shuffle_subjects, shuffle_patches and verbose into a single configuration object.
- Split the method so that each part only needs a subset of the data.
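One way to apply the first approach to the flagged constructor would be a small dataclass that bundles the configuration values. This is only an illustration: the QueueSettings name and its fields are hypothetical (not part of torchio), although the field names mirror the parameters of Queue.__init__() shown in the listing below.

from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class QueueSettings:
    """Hypothetical parameter object bundling the Queue configuration."""
    max_length: int
    samples_per_volume: int
    patch_size: Union[int, Tuple[int, int, int]]
    num_workers: int = 0
    shuffle_subjects: bool = True
    shuffle_patches: bool = True
    verbose: bool = False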

import random
import warnings
from typing import List, Iterator
from itertools import islice
from tqdm import trange
from torch.utils.data import Dataset, DataLoader
from .. import TypeTuple
from .dataset import ImagesDataset
from .sampler import ImageSampler


class Queue(Dataset):
    r"""Patches queue used for patch-based training.

    Args:
        subjects_dataset: Instance of
            :class:`~torchio.data.dataset.ImagesDataset`.
        max_length: Maximum number of patches that can be stored in the queue.
            Using a large number means that the queue needs to be filled less
            often, but more RAM is needed to store the patches.
        samples_per_volume: Number of patches to extract from each volume.
            A small number of patches ensures a large variability in the queue,
            but training will be slower.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        sampler_class: An instance of :class:`~torchio.data.sampler.ImageSampler`
            used to define the patches sampling strategy.
        num_workers: Number of subprocesses to use for data loading
            (as in :class:`torch.utils.data.DataLoader`).
            ``0`` means that the data will be loaded in the main process.
        shuffle_subjects: If ``True``, the subjects dataset is shuffled at the
            beginning of each epoch, i.e. when all patches from all subjects
            have been processed.
        shuffle_patches: If ``True``, patches are shuffled after filling the
            queue.
        verbose: If ``True``, some debugging messages are printed.

    .. note:: :attr:`num_workers` refers to the number of workers used to
        load and transform the volumes. Multiprocessing is not needed to pop
        patches from the queue.

    Example:

    >>> from torch.utils.data import DataLoader
    >>> import torchio
    >>> patches_queue = torchio.Queue(
    ...     subjects_dataset=subjects_dataset,  # instance of torchio.ImagesDataset
    ...     max_length=300,
    ...     samples_per_volume=10,
    ...     patch_size=96,
    ...     sampler_class=torchio.sampler.ImageSampler,
    ...     num_workers=4,
    ...     shuffle_subjects=True,
    ...     shuffle_patches=True,
    ... )
    >>> patches_loader = DataLoader(patches_queue, batch_size=4)
    >>> num_epochs = 20
    >>> for epoch_index in range(num_epochs):
    ...     for patches_batch in patches_loader:
    ...         inputs = patches_batch['image_name'][torchio.DATA]
    ...         targets = patches_batch['targets_name'][torchio.DATA]
    ...         logits = model(inputs)  # model is some torch.nn.Module

    """
    def __init__(
            self,
            subjects_dataset: ImagesDataset,
            max_length: int,
            samples_per_volume: int,
            patch_size: TypeTuple,
            sampler_class: ImageSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
            ):
        self.subjects_dataset = subjects_dataset
        self.max_length = max_length
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.samples_per_volume = samples_per_volume
        self.sampler_class = sampler_class
        self.patch_size = patch_size
        self.num_workers = num_workers
        self.verbose = verbose
        self.subjects_iterable = self.get_subjects_iterable()
        self.patches_list: List[dict] = []
        self.num_sampled_patches = 0

    def __len__(self):
        return self.iterations_per_epoch

    def __getitem__(self, _):
        # There are probably more elegant ways of doing this
        if not self.patches_list:
            self.print('Patches list is empty.')
            self.fill()
        sample_patch = self.patches_list.pop()
        self.num_sampled_patches += 1
        return sample_patch

    def __repr__(self):
        attributes = [
            f'max_length={self.max_length}',
            f'num_subjects={self.num_subjects}',
            f'num_patches={self.num_patches}',
            f'samples_per_volume={self.samples_per_volume}',
            f'num_sampled_patches={self.num_sampled_patches}',
            f'iterations_per_epoch={self.iterations_per_epoch}',
        ]
        attributes_string = ', '.join(attributes)
        return f'Queue({attributes_string})'

    def print(self, *args):
        if self.verbose:
            print(*args)

    @property
    def num_subjects(self) -> int:
        return len(self.subjects_dataset)

    @property
    def num_patches(self) -> int:
        return len(self.patches_list)

    @property
    def iterations_per_epoch(self) -> int:
        return self.num_subjects * self.samples_per_volume

    def fill(self) -> None:
        assert self.sampler_class is not None
        assert self.patch_size is not None
        if self.max_length % self.samples_per_volume != 0:
            message = (
                f'Queue length ({self.max_length})'
                ' not divisible by the number of'
                f' patches per volume ({self.samples_per_volume})'
            )
            warnings.warn(message)

        # If there are e.g. 4 subjects and 1 sample per volume and max_length
        # is 6, we just need to load 4 subjects, not 6
        max_num_subjects_for_queue = self.max_length // self.samples_per_volume
        num_subjects_for_queue = min(
            self.num_subjects, max_num_subjects_for_queue)

        self.print(f'Filling queue from {num_subjects_for_queue} subjects...')
        if self.verbose:
            iterable = trange(num_subjects_for_queue, leave=False)
        else:
            iterable = range(num_subjects_for_queue)
        for _ in iterable:
            subject_sample = self.get_next_subject_sample()
            sampler = self.sampler_class(subject_sample, self.patch_size)
            samples = list(islice(sampler, self.samples_per_volume))
            self.patches_list.extend(samples)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject_sample(self) -> dict:
        # A StopIteration exception is expected when the queue is empty
        try:
            subject_sample = next(self.subjects_iterable)
        except StopIteration as exception:
            self.print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject_sample = next(self.subjects_iterable)
        return subject_sample

    def get_subjects_iterable(self) -> Iterator:
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self.print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=lambda x: x[0],
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
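With a parameter object like the hypothetical QueueSettings sketched above, the constructor flagged in this report (nop 10, counting self) could shrink to four parameters while the rest of the class stays unchanged. The snippet below is only a sketch of that refactoring, not torchio code; forwarding properties can preserve the attribute names that fill(), __repr__() and get_subjects_iterable() already rely on.

from typing import List

from torch.utils.data import Dataset


class Queue(Dataset):
    """Hypothetical refactoring of the flagged constructor, not torchio code."""

    def __init__(self, subjects_dataset, sampler_class, settings):
        # The seven configuration values travel inside a single
        # QueueSettings object instead of seven separate parameters.
        self.subjects_dataset = subjects_dataset
        self.sampler_class = sampler_class
        self.settings = settings
        self.patches_list: List[dict] = []
        self.num_sampled_patches = 0

    @property
    def max_length(self) -> int:
        # A forwarding property keeps existing attribute accesses
        # (e.g. self.max_length in fill()) working without touching any
        # other method; the same pattern applies to the remaining fields.
        return self.settings.max_length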