| Total Complexity | 5 |
| Total Lines | 34 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
from typing import Generator, Optional

import numpy as np
import torch

from ...data.subject import Subject
from .sampler import RandomSampler
||
| 6 | |||
| 7 | |||
class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
    """

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return a constant probability map for ``subject``.

        Every voxel receives the same value (1), so no location is favoured.

        Args:
            subject: Subject whose ``spatial_shape`` sizes the map.

        Returns:
            A tensor of ones with shape ``(1, *subject.spatial_shape)``.
        """
        return torch.ones(1, *subject.spatial_shape)

    def _generate_patches(
        self,
        subject: Subject,
        num_patches: Optional[int] = None,
    ) -> Generator[Subject, None, None]:
        """Yield patches whose starting index is drawn uniformly at random.

        Each axis of the starting index is sampled independently and
        uniformly from ``0`` to ``spatial_shape - patch_size`` (inclusive).

        Args:
            subject: Subject to extract patches from.
            num_patches: Number of patches to yield. If ``None``, the
                generator never terminates and the caller must stop
                iterating.

        Yields:
            The result of ``self.extract_patch`` for each sampled index.

        Raises:
            ValueError: If ``num_patches`` is negative. Because this is a
                generator, the error surfaces on the first ``next()`` call.
        """
        if num_patches is not None and num_patches < 0:
            # A negative count used to be truthy and was decremented away
            # from zero forever, silently producing an infinite stream.
            raise ValueError(
                'num_patches must be None or non-negative, got {}'.format(
                    num_patches,
                )
            )
        # Per-axis inclusive upper bound for the patch's starting index.
        valid_range = subject.spatial_shape - self.patch_size
        remaining = num_patches
        while remaining is None or remaining > 0:
            # torch.randint(high, size) samples from [0, high), so ``x + 1``
            # makes ``x`` itself a reachable starting index.
            index_ini = [
                torch.randint(x + 1, (1,)).item()
                for x in valid_range
            ]
            yield self.extract_patch(subject, np.asarray(index_ini))
            if remaining is not None:
                remaining -= 1
||
| 34 |