| Total Complexity | 2 |
| Total Lines | 26 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | import torch |
||
| 2 | import torchio |
||
| 3 | import numpy as np |
||
| 4 | from torchio.data import UniformSampler |
||
| 5 | from ...utils import TorchioTestCase |
||
| 6 | |||
| 7 | |||
class TestUniformSampler(TorchioTestCase):
    """Tests for `UniformSampler` class."""

    def test_uniform_probabilities(self):
        """The raw probability map of a uniform sampler is all ones."""
        patch_size = 5
        sampler = UniformSampler(patch_size)
        probabilities = sampler.get_probability_map(self.sample)
        fixtures = torch.ones_like(probabilities)
        # torch.equal also checks that shapes match, and gives a plain bool
        # instead of a 0-d tensor.
        self.assertTrue(torch.equal(probabilities, fixtures))

    def test_processed_uniform_probabilities(self):
        """After processing, probability mass is uniform over valid centers.

        Voxels closer than ``patch_size // 2`` to any edge cannot be patch
        centers and must have zero probability. All remaining voxels share
        the same value, and the whole map sums to one.
        """
        # Odd patch size, so the excluded border is the same width on both
        # sides of each axis (2 voxels for a patch size of 5).
        patch_size = 5
        border = patch_size // 2
        sampler = UniformSampler(patch_size)
        probabilities = sampler.get_probability_map(self.sample)
        probabilities = sampler.process_probability_map(probabilities)
        # Other positions cannot be patch centers, so they stay zero.
        # Assumes the processed map is a 3D array, as the original indexing
        # implied.
        fixtures = np.zeros_like(probabilities)
        fixtures[border:-border, border:-border, border:-border] = (
            probabilities[border, border, border]
        )
        self.assertAlmostEqual(probabilities.sum(), 1)
        # Reports the mismatching positions on failure, unlike a bare assert.
        np.testing.assert_array_equal(probabilities, fixtures)
||
| 26 |