Total Complexity | 6 |
Total Lines | 45 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | from torch.utils.data import DataLoader |
||
2 | from torchio.data import UniformSampler |
||
3 | from torchio import SubjectsDataset, Queue, DATA |
||
4 | from torchio.utils import create_dummy_dataset |
||
5 | from ..utils import TorchioTestCase |
||
6 | |||
7 | |||
class TestQueue(TorchioTestCase):
    """Tests for `queue` module."""

    def setUp(self):
        """Create ten dummy subjects on disk for the queue tests."""
        super().setUp()
        # Images are written as .nii files under self.dir (presumably a
        # per-test directory prepared by TorchioTestCase.setUp — confirm).
        self.subjects_list = create_dummy_dataset(
            num_images=10,
            size_range=(10, 20),
            directory=self.dir,
            suffix='.nii',
            force=False,
        )

    def run_queue(self, num_workers, **kwargs):
        """Build a patch queue over the dummy subjects and iterate it once.

        Args:
            num_workers: Number of worker processes the ``Queue`` uses to
                load subjects. It is forwarded to ``Queue`` so that callers
                requesting workers actually exercise that code path.
            **kwargs: Extra keyword arguments forwarded verbatim to ``Queue``
                (e.g. ``start_background``).
        """
        subjects_dataset = SubjectsDataset(self.subjects_list)
        patch_size = 10
        sampler = UniformSampler(patch_size)
        queue_dataset = Queue(
            subjects_dataset,
            max_length=6,
            samples_per_volume=2,
            sampler=sampler,
            num_workers=num_workers,
            **kwargs,
        )
        # Smoke-test the string representation of the queue.
        _ = str(queue_dataset)
        # Parallel loading is requested from the Queue above; the DataLoader
        # keeps its default single-process loading on top of it.
        batch_loader = DataLoader(queue_dataset, batch_size=4)
        num_batches = 0
        for batch in batch_loader:
            _ = batch['one_modality'][DATA]
            _ = batch['segmentation'][DATA]
            num_batches += 1
        # Without this, a queue that yields nothing would pass silently.
        self.assertGreater(num_batches, 0)

    def test_queue(self):
        """Iterate the queue without worker processes."""
        self.run_queue(num_workers=0)

    def test_queue_multiprocessing(self):
        """Iterate the queue with two worker processes loading subjects."""
        self.run_queue(num_workers=2)

    def test_queue_no_start_background(self):
        """Iterate the queue with ``start_background=False``."""
        self.run_queue(num_workers=0, start_background=False)
45 |