Passed
Pull Request — master (#175)
by Fernando
58s
created

torchio.data.sampler.label.LabelSampler.extract_patch()   A

Complexity

Conditions 2

Size

Total Lines 14
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 13
nop 1
dl 0
loc 14
rs 9.75
c 0
b 0
f 0

1 Method

Rating   Name   Duplication   Size   Complexity  
A torchio.data.sampler.label.LabelSampler.get_probability_map() 0 11 2
1
import torch
2
from ...data.subject import Subject
3
from ...torchio import TypePatchSize
4
from .weighted import WeightedSampler
5
6
7
class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    This sampler yields patches whose center value is greater than 0.5
    in the :py:attr:`label_name`, i.e. the label image is binarized with a
    threshold of 0.5 to build the sampling probability map.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the sample that will be used to
            generate the sampling probability map.

    Example:
        >>> import torchio
        >>> subject = torchio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sampler = torchio.data.LabelSampler(64, 'brain')
        >>> generator = sampler(sample)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(sample, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)

    """

    def __init__(self, patch_size: TypePatchSize, label_name: str):
        # The label image name is forwarded to the parent as the name of the
        # image that serves as the probability map.
        super().__init__(patch_size, probability_map=label_name)

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return the binarized label image used for sampling.

        Args:
            sample: Subject expected to contain an image whose key is
                ``self.probability_map_name``.

        Returns:
            The element-wise comparison ``data > 0.5`` of that image's data:
            ``True`` where the label value exceeds 0.5, ``False`` elsewhere.

        Raises:
            KeyError: If ``self.probability_map_name`` is not in ``sample``.
        """
        if self.probability_map_name not in sample:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        # Binarize the label data with a fixed 0.5 threshold.
        return sample[self.probability_map_name].data > 0.5
50