| Total Complexity | 4 |
| Total Lines | 34 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
from typing import Generator, Optional, Tuple

import numpy as np
import torch

from ...data.subject import Subject
from ...torchio import TypePatchSize
from .sampler import RandomSampler
||
| 6 | |||
| 7 | |||
class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    The patch corner is drawn independently along each spatial axis,
    uniformly from ``0`` to ``image_size - patch_size`` (inclusive).
    Every valid patch location is therefore equally likely.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
    """

    def __init__(self, patch_size: TypePatchSize):
        super().__init__(patch_size)

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return a constant (all-ones) probability map for ``sample``.

        Args:
            sample: Subject whose ``spatial_shape`` determines the map size.

        Returns:
            Tensor of ones with shape ``(1, *sample.spatial_shape)``.
        """
        return torch.ones(1, *sample.spatial_shape)

    def __call__(self, sample: Subject) -> Generator[Subject, None, None]:
        """Yield a patch extracted at a uniformly random valid location.

        Args:
            sample: Subject to extract the patch from.
                ``sample.check_consistent_spatial_shape()`` is called first.
                Presumably it rejects subjects whose images differ in
                spatial shape (TODO confirm).

        Yields:
            The result of ``self.extract_patch(sample, corners)``, where
            ``corners`` is the randomly drawn patch origin.
            NOTE(review): only a single patch is yielded per call. Confirm
            this is intended rather than an endless stream of patches.

        Raises:
            RuntimeError: If the patch size exceeds the image size along
                any spatial axis.
        """
        sample.check_consistent_spatial_shape()

        # Normalize both shapes to arrays once. The element-wise comparison
        # and the subtraction below then behave the same whether the
        # attributes are tuples or arrays.
        patch_size = np.asarray(self.patch_size)
        spatial_shape = np.asarray(sample.spatial_shape)

        if np.any(patch_size > spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)

        # Largest valid corner index along each axis (inclusive).
        valid_range = spatial_shape - patch_size
        # torch.randint's upper bound is exclusive, hence the +1.
        # Drawing with torch keeps sampling governed by torch's RNG seed.
        corners = np.asarray([
            torch.randint(int(max_corner) + 1, (1,)).item()
            for max_corner in valid_range
        ])
        yield self.extract_patch(sample, corners)
||
| 34 |