Passed
Push — master ( 672991...862c9a )
by Fernando
03:11 queued 02:03
created

torchio.data.sampler.uniform   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 44
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 6
eloc 29
dl 0
loc 44
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
A UniformSampler.get_probability_map() 0 2 1
B UniformSampler.__call__() 0 25 5
1
from typing import Generator, Optional

import numpy as np
import torch

from ...data.subject import Subject
from ...typing import TypePatchSize
from .sampler import RandomSampler
7
8
9
class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    Each patch's starting index is drawn independently and uniformly along
    every spatial axis. The draw is limited to starting positions for which
    the whole patch fits inside the volume.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
    """

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return a constant map of ones matching the subject's spatial shape.

        The returned tensor has a leading singleton dimension followed by
        ``subject.spatial_shape`` (presumably a channel axis; confirm
        against the base sampler's expectations).

        Args:
            subject: Subject whose ``spatial_shape`` sizes the map.

        Returns:
            A tensor of ones with shape ``(1, *subject.spatial_shape)``.
        """
        return torch.ones(1, *subject.spatial_shape)

    def __call__(
            self,
            subject: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches whose starting indices are drawn uniformly at random.

        Args:
            subject: Subject from which patches are extracted. Its spatial
                shape is validated with
                ``subject.check_consistent_spatial_shape()``.
            num_patches: Number of patches to yield. If ``None``, patches
                are yielded indefinitely.

        Yields:
            The result of :meth:`extract_patch` for each sampled location.

        Raises:
            ValueError: If ``num_patches`` is negative.
            RuntimeError: If the patch size exceeds the subject's spatial
                shape along any axis.

        Note:
            This method is a generator, so the checks above run when the
            first patch is requested, not when the method is called.
        """
        # Previously a negative count never reached zero and looped forever.
        if num_patches is not None and num_patches < 0:
            raise ValueError(
                f'num_patches must be non-negative or None, got {num_patches}'
            )

        subject.check_consistent_spatial_shape()

        if np.any(self.patch_size > subject.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(subject.spatial_shape)}'
            )
            raise RuntimeError(message)

        # Largest valid starting index along each axis (inclusive), so the
        # patch stays entirely inside the volume.
        valid_range = subject.spatial_shape - self.patch_size
        patches_yielded = 0
        while num_patches is None or patches_yielded < num_patches:
            # torch.randint's upper bound is exclusive, hence ``x + 1`` to
            # make ``x`` itself a possible start index. Drawing through torch
            # keeps sampling reproducible under ``torch.manual_seed``.
            index_ini = [
                torch.randint(x + 1, (1,)).item()
                for x in valid_range
            ]
            yield self.extract_patch(subject, np.asarray(index_ini))
            patches_yielded += 1
44