Total Complexity | 4 |
Total Lines | 35 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | import torch |
||
2 | from ...data.subject import Subject |
||
3 | from ...torchio import TypePatchSize |
||
4 | from .sampler import RandomSampler |
||
5 | from typing import Optional, Tuple, Generator |
||
6 | import numpy as np |
||
7 | |||
8 | |||
class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    Every patch location whose patch fits entirely inside the volume is
    equally likely to be chosen.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
    """
    def __init__(self, patch_size: TypePatchSize):
        super().__init__(patch_size)

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return a constant map of ones with shape ``(1, *spatial_shape)``."""
        return torch.ones(1, *sample.spatial_shape)

    def __call__(
            self,
            sample: Subject,
            num_patches: Optional[int] = 1,
            ) -> Generator[Subject, None, None]:
        """Yield patches whose corner indices are drawn uniformly at random.

        Args:
            sample: Subject to extract patches from. Its images must pass
                ``check_consistent_spatial_shape()``.
            num_patches: Number of patches to yield. Defaults to ``1``,
                which matches the previous behaviour. Pass ``None`` to yield
                patches indefinitely; the consumer then decides how many to
                take.

        Raises:
            RuntimeError: If the patch size is larger than the image size
                along any axis.
            ValueError: If ``num_patches`` is negative.
        """
        if num_patches is not None and num_patches < 0:
            raise ValueError(
                f'num_patches must be None or >= 0, got {num_patches}')

        sample.check_consistent_spatial_shape()

        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)

        # Largest valid corner index along each axis (inclusive).
        valid_range = sample.spatial_shape - self.patch_size

        patches_left = num_patches
        while patches_left is None or patches_left > 0:
            # Draw from torch's RNG rather than np.random: PyTorch reseeds its
            # own generator in every DataLoader worker, whereas forked workers
            # inherit identical NumPy global state and would all draw the
            # same "random" corners.
            corners = np.array([
                torch.randint(int(high) + 1, (1,)).item()
                for high in valid_range
            ])
            yield self.extract_patch(sample, corners)
            if patches_left is not None:
                patches_left -= 1
||
35 |