Conditions | 5 |
Total Lines | 25 |
Code Lines | 19 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
1 | import torch |
||
def __call__(
        self,
        subject: Subject,
        num_patches: int = None,
        ) -> Generator[Subject, None, None]:
    """Yield patches taken at random positions inside ``subject``.

    For every patch, the corner index along each spatial axis is drawn
    independently and uniformly from ``0`` to
    ``spatial_shape - patch_size`` (inclusive) with ``torch.randint``.
    The patch itself is built by ``self.extract_patch``.

    Args:
        subject: Subject to sample from. Its images must share one spatial
            shape (checked with ``check_consistent_spatial_shape``).
        num_patches: Number of patches to yield (``Optional[int]``).
            ``None`` (the default) yields patches indefinitely; ``0``
            yields nothing.

    Yields:
        The result of ``self.extract_patch(subject, index_ini)`` for each
        sampled corner index ``index_ini`` (a NumPy array).

    Raises:
        ValueError: If ``num_patches`` is negative.
        RuntimeError: If ``self.patch_size`` exceeds the subject's spatial
            shape along any axis.

    Note:
        Because this is a generator, all checks run on the first ``next()``
        call, not when ``__call__`` is invoked.
    """
    import itertools  # local import keeps this block self-contained

    if num_patches is not None and num_patches < 0:
        raise ValueError(
            f'num_patches must be a non-negative integer or None,'
            f' got {num_patches}'
        )

    subject.check_consistent_spatial_shape()

    if np.any(self.patch_size > subject.spatial_shape):
        message = (
            f'Patch size {tuple(self.patch_size)} cannot be'
            f' larger than image size {tuple(subject.spatial_shape)}'
        )
        raise RuntimeError(message)

    # Largest valid corner index per axis; randint's upper bound is
    # exclusive, hence the ``+ 1`` below.
    valid_range = subject.spatial_shape - self.patch_size

    # A bounded range (or an unbounded counter for None) always terminates
    # correctly; a hand-decremented counter never reaches 0 when negative.
    counter = itertools.count() if num_patches is None else range(num_patches)
    for _ in counter:
        index_ini = [
            torch.randint(x + 1, (1,)).item()
            for x in valid_range
        ]
        index_ini_array = np.asarray(index_ini)
        yield self.extract_patch(subject, index_ini_array)