Passed
Pull Request — master (#296)
by Fernando
01:09
created

torchio.data.sampler.uniform   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 35
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 35
rs 10
c 0
b 0
f 0
wmc 4

3 Methods

Rating   Name   Duplication   Size   Complexity  
A UniformSampler.__init__() 0 2 1
A UniformSampler.get_probability_map() 0 2 1
A UniformSampler.__call__() 0 14 2
1
import torch
2
from ...data.subject import Subject
3
from ...torchio import TypePatchSize
4
from .sampler import RandomSampler
5
from typing import Optional, Tuple, Generator
6
import numpy as np
7
8
9
class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    Every patch location at which the patch fits entirely inside the volume
    is equally likely to be chosen.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
    """
    def __init__(self, patch_size: TypePatchSize):
        # The second argument is passed as None: this sampler draws start
        # indices directly instead of using a precomputed map.
        super().__init__(patch_size, None)

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return a tensor of ones with shape ``(1, *sample.spatial_shape)``.

        All voxels get the same (unnormalized) probability.
        """
        return torch.ones(1, *sample.spatial_shape)

    def __call__(
            self,
            sample: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches extracted at uniformly random locations.

        Args:
            sample: Subject to extract patches from.
                ``sample.check_consistent_spatial_shape()`` is called first.
            num_patches: Number of patches to yield. If ``None`` (default),
                the generator never stops on its own. Consumers must take
                only as many patches as they need, e.g. with
                :func:`itertools.islice`.

        Raises:
            RuntimeError: If the patch size is larger than the image size
                along any dimension. Because this is a generator, the error
                surfaces on the first ``next()`` call.
        """
        sample.check_consistent_spatial_shape()

        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)

        # Largest valid start index per dimension; randint's upper bound is
        # exclusive, hence the ``+ 1`` below.
        valid_range = sample.spatial_shape - self.patch_size
        patches_left = num_patches
        # Previously only a single patch was yielded, so consumers asking
        # the generator for several patches per volume silently got one.
        while patches_left is None or patches_left > 0:
            index_ini = [
                torch.randint(x + 1, (1,)).item()
                for x in valid_range
            ]
            index_ini_array = np.asarray(index_ini)
            yield self.extract_patch(sample, index_ini_array)
            if patches_left is not None:
                patches_left -= 1
35