| Total Complexity | 6 |
| Total Lines | 45 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | from torch.utils.data import DataLoader |
||
| 2 | from torchio.data import UniformSampler |
||
| 3 | from torchio import SubjectsDataset, Queue, DATA |
||
| 4 | from torchio.utils import create_dummy_dataset |
||
| 5 | from ..utils import TorchioTestCase |
||
| 6 | |||
| 7 | |||
class TestQueue(TorchioTestCase):
    """Tests for `queue` module."""

    def setUp(self):
        """Create a dummy dataset of 10 subjects on disk before each test."""
        super().setUp()
        # NOTE(review): self.dir is presumably a scratch directory set up by
        # TorchioTestCase.setUp — confirm against the base class.
        self.subjects_list = create_dummy_dataset(
            num_images=10,
            size_range=(10, 20),
            directory=self.dir,
            suffix='.nii',
            force=False,
        )

    def run_queue(self, num_workers, **kwargs):
        """Build a patch ``Queue`` over the dummy subjects and drain it.

        Args:
            num_workers: Number of worker processes the ``Queue`` uses to
                load subjects; forwarded to ``Queue``.
            **kwargs: Extra keyword arguments forwarded to ``Queue``
                (e.g. ``start_background``).
        """
        subjects_dataset = SubjectsDataset(self.subjects_list)
        patch_size = 10
        sampler = UniformSampler(patch_size)
        queue_dataset = Queue(
            subjects_dataset,
            max_length=6,
            samples_per_volume=2,
            sampler=sampler,
            # Without this, the multiprocessing test silently ran in a
            # single process because num_workers was never used.
            num_workers=num_workers,
            **kwargs,
        )
        _ = str(queue_dataset)  # smoke-test the string representation
        # The Queue performs the parallel subject loading itself, so the
        # outer DataLoader keeps its default of no worker processes.
        batch_loader = DataLoader(queue_dataset, batch_size=4)
        num_batches = 0
        for batch in batch_loader:
            _ = batch['one_modality'][DATA]
            _ = batch['segmentation'][DATA]
            num_batches += 1
        # Fail loudly if the queue produced nothing instead of passing
        # vacuously.
        assert num_batches > 0

    def test_queue(self):
        """Drain the queue with subjects loaded in the main process."""
        self.run_queue(num_workers=0)

    def test_queue_multiprocessing(self):
        """Drain the queue with two subject-loading worker processes."""
        self.run_queue(num_workers=2)

    def test_queue_no_start_background(self):
        """Drain the queue with ``start_background=False``."""
        self.run_queue(num_workers=0, start_background=False)
||
| 45 |