| Total Complexity | 2 |
| Total Lines | 13 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | from torchio import DATA |
||
| 2 | from torchio.data import LabelSampler |
||
| 3 | from ...utils import TorchioTestCase |
||
| 4 | |||
| 5 | |||
class TestLabelSampler(TorchioTestCase):
    """Unit tests for the `LabelSampler` patch sampler."""

    def test_label_sampler(self):
        """Every sampled patch must hold label value 1 at its middle voxel.

        A `LabelSampler` with patch size 5 and label key ``'label'`` is run
        on ``self.sample`` (provided by `TorchioTestCase`) to draw 10
        patches. For each patch, the label tensor is read at index
        ``[0, 2, 2, 2]`` and compared with 1.
        """
        patch_size = 5
        # Middle index of a 5-voxel-wide patch along each spatial axis (== 2).
        middle = patch_size // 2
        label_sampler = LabelSampler(patch_size, 'label')
        sampled_patches = label_sampler(self.sample, num_patches=10)
        for current_patch in sampled_patches:
            label_tensor = current_patch['label'][DATA]
            # NOTE(review): leading index 0 presumably selects the first
            # channel of the label tensor; confirm against torchio's layout.
            middle_value = label_tensor[0, middle, middle, middle]
            self.assertEqual(middle_value, 1)
||
| 13 |