| Total Complexity | 2 |
| Total Lines | 26 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | import torch |
||
| 2 | import torchio |
||
| 3 | from torchio.data import WeightedSampler |
||
| 4 | from ...utils import TorchioTestCase |
||
| 5 | |||
| 6 | |||
class TestWeightedSampler(TorchioTestCase):
    """Exercise patch extraction through the `WeightedSampler` class."""

    def test_weighted_sampler(self):
        # The probability map built by ``get_sample`` is zero everywhere
        # except voxel (3, 3, 3). The first patch drawn with patch size 5
        # must report 'index_ini' == (1, 1, 1).
        # NOTE(review): 'index_ini' is presumably the patch's starting
        # corner in the image -- confirm against the sampler implementation.
        patch_size = 5
        expected_index_ini = (1, 1, 1)

        subject_sample = self.get_sample((7, 7, 7))
        patch_sampler = WeightedSampler(patch_size, 'prob')
        patches = iter(patch_sampler(subject_sample))
        first_patch = next(patches)

        self.assertEqual(
            tuple(first_patch['index_ini']),
            expected_index_ini,
        )

    def get_sample(self, image_shape):
        """Build a one-subject dataset and return its first sample.

        The subject holds two images of shape ``image_shape``:

        * ``t1``: uniformly random values from ``torch.rand``.
        * ``prob``: all zeros except a single 1 at index ``[3, 3, 3]``.

        Args:
            image_shape: Sequence of dimension sizes, unpacked into
                ``torch.rand``. The fixed index ``[3, 3, 3]`` requires three
                dimensions of size at least 4 each.

        Returns:
            Element 0 of a ``torchio.ImagesDataset`` wrapping the subject.
        """
        intensities = torch.rand(*image_shape)

        # Single nonzero entry in an otherwise all-zero map.
        probability_map = torch.zeros_like(intensities)
        probability_map[3, 3, 3] = 1

        t1_image = torchio.Image(tensor=intensities)
        prob_image = torchio.Image(tensor=probability_map)
        subject = torchio.Subject(t1=t1_image, prob=prob_image)

        dataset = torchio.ImagesDataset([subject])
        return dataset[0]
||
| 26 |