Passed
Pull Request — master (#175)
by Fernando
01:24
created

LabelSampler.get_probability_map()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 2
dl 0
loc 11
rs 10
c 0
b 0
f 0
1
from .weighted import WeightedSampler
2
3
4
class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    This sampler yields patches whose center voxel is labeled, i.e. whose
    value in the label image named by ``label_name`` is greater than 0.5.
    For integer label maps this is equivalent to selecting voxels whose
    label is greater than 0.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the sample that will be used to
            generate the sampling probability map.

    Example:
        >>> import torchio
        >>> subject = torchio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sampler = torchio.data.LabelSampler(64, 'brain')
        >>> generator = sampler(sample)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(sample, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)

    """
    def __init__(self, patch_size, label_name):
        # The label image name is handed to the parent as its
        # ``probability_map``; get_probability_map below reads it back as
        # ``self.probability_map_name`` (presumably set by WeightedSampler).
        super().__init__(patch_size, probability_map=label_name)

    def get_probability_map(self, sample):
        """Return binarized image for sampling.

        Args:
            sample: Subject sample expected to contain the label image named
                by ``self.probability_map_name``.

        Returns:
            The result of comparing the label image data with 0.5, which is
            true wherever the label value is greater than 0.5.

        Raises:
            KeyError: If the label image is not present in ``sample``.
        """
        if self.probability_map_name not in sample:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        # Thresholding at 0.5 (rather than 0) treats every integer label >= 1
        # as foreground while being tolerant of small float noise near 0.
        return sample[self.probability_map_name].data > 0.5
47