Passed
Pull Request — master (#296)
by David
01:15
created

UniformSampler.__call__()   A

Complexity

Conditions 2

Size

Total Lines 14
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 9
nop 2
dl 0
loc 14
rs 9.95
c 0
b 0
f 0
1
import torch
2
from ...data.subject import Subject
3
from ...torchio import TypePatchSize
4
from .sampler import RandomSampler
5
from typing import Optional, Tuple, Generator
6
7
8
class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
    """
    def __init__(self, patch_size: TypePatchSize):
        super().__init__(patch_size)

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return a tensor of ones with shape ``(1, *sample.spatial_shape)``.

        Args:
            sample: Subject whose ``spatial_shape`` defines the map size.
        """
        return torch.ones(1, *sample.spatial_shape)

    def __call__(self, sample: Subject) -> Generator[Subject, None, None]:
        """Yield a single patch extracted at a random position.

        This is a generator: the consistency check, the size validation and
        the random draw all run when the first item is requested, not when
        the sampler is called.

        Args:
            sample: Subject to extract the patch from.

        Yields:
            The result of ``self.extract_patch(sample, corners)``, where
            ``corners`` holds one randomly drawn index per spatial dimension.

        Raises:
            RuntimeError: If the patch size is larger than the image size
                along any spatial dimension.
        """
        # Imported locally because the module's import block does not bring
        # NumPy into scope, although this method relies on it.
        import numpy as np

        sample.check_consistent_spatial_shape()

        # Coerce both operands to arrays so the comparison and subtraction
        # below are element-wise whatever container types they arrive in.
        # NOTE(review): self.patch_size is presumably set by the parent
        # class from the constructor argument -- confirm.
        patch_size = np.asarray(self.patch_size)
        spatial_shape = np.asarray(sample.spatial_shape)

        if np.any(patch_size > spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)

        # Largest admissible corner index along each dimension.
        valid_range = spatial_shape - patch_size
        # torch.randint excludes its upper bound, hence the +1 so that the
        # last valid corner can be drawn too. Drawing with torch keeps the
        # sampling tied to torch's random state.
        corners = np.asarray([
            torch.randint(int(extent) + 1, (1,)).item()
            for extent in valid_range
        ])
        yield self.extract_patch(sample, corners)
34