Passed
Pull Request — master (#403)
by Fernando
01:12
created

UniformSampler.__call__()   B

Complexity

Conditions 5

Size

Total Lines 25
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 19
nop 3
dl 0
loc 25
rs 8.9833
c 0
b 0
f 0
from typing import Generator, Optional

import numpy as np
import torch

from ...data.subject import Subject
from ...typing import TypePatchSize
from .sampler import RandomSampler


class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
    """

    def __init__(self, patch_size: TypePatchSize):
        super().__init__(patch_size)

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return a map of ones with shape ``(1, *subject.spatial_shape)``.

        Every voxel gets the same value, i.e. all locations are equally
        likely.
        """
        return torch.ones(1, *subject.spatial_shape)

    def __call__(
            self,
            subject: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches whose starting indices are drawn uniformly.

        Because this is a generator, the checks below run (and may raise)
        only when iteration starts, not when ``__call__`` is invoked.

        Args:
            subject: Subject to sample from. Its images must share the same
                spatial shape (checked via
                ``subject.check_consistent_spatial_shape()``).
            num_patches: Number of patches to yield. If ``None``, patches
                are yielded indefinitely.

        Yields:
            Patches produced by ``self.extract_patch``.

        Raises:
            ValueError: If ``num_patches`` is negative.
            RuntimeError: If the patch size exceeds the subject's spatial
                shape along any axis.
        """
        # A negative count previously stayed truthy forever (-1, -2, ...),
        # silently turning the generator into an infinite one.
        if num_patches is not None and num_patches < 0:
            message = (
                f'num_patches must be non-negative or None, not {num_patches}'
            )
            raise ValueError(message)

        subject.check_consistent_spatial_shape()

        # NOTE(review): assumes self.patch_size is an array set by the base
        # class so these comparisons/subtractions are elementwise — confirm.
        if np.any(self.patch_size > subject.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(subject.spatial_shape)}'
            )
            raise RuntimeError(message)

        # Largest valid starting index along each axis (inclusive).
        valid_range = subject.spatial_shape - self.patch_size
        patches_left = num_patches
        while patches_left is None or patches_left > 0:
            # torch.randint's upper bound is exclusive, hence the + 1 so the
            # last valid start index along each axis can also be drawn.
            index_ini = [
                torch.randint(x + 1, (1,)).item()
                for x in valid_range
            ]
            index_ini_array = np.asarray(index_ini)
            yield self.extract_patch(subject, index_ini_array)
            if patches_left is not None:
                patches_left -= 1
46