Passed
Pull Request — master (#175)
by Fernando
01:15
created

LabelSampler.get_probability_map()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 2
dl 0
loc 11
rs 10
c 0
b 0
f 0
1
from .weighted import WeightedSampler
2
3
4
class LabelSampler(WeightedSampler):
    r"""Extract random patches containing labeled voxels.

    This iterable dataset yields patches whose center value is greater than
    0.5 in the label image named ``label_name``. The label image is binarized
    with that threshold, and the resulting mask is what
    :py:meth:`get_probability_map` returns for sampling.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        label_name: Name of the label image in the sample that will be used to
            generate the sampling probability map.
    """
    def __init__(self, patch_size, label_name):
        # The label name is forwarded to the parent as the probability map
        # name; presumably WeightedSampler stores it as
        # ``probability_map_name``, which get_probability_map() reads below.
        super().__init__(patch_size, probability_map=label_name)

    def get_probability_map(self, sample):
        """Return binarized image for sampling.

        Args:
            sample: Subject sample mapping image names to images. The entry
                named by ``self.probability_map_name`` must expose its voxel
                values through a ``.data`` attribute.

        Returns:
            The result of ``data > 0.5``: a boolean mask that is true where
            the label value exceeds 0.5 (its container type follows that of
            ``.data``, presumably a tensor).

        Raises:
            KeyError: If the label image is not present in ``sample``.
        """
        if self.probability_map_name not in sample:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        # Binarize at 0.5: only voxels with a label value above 0.5 are
        # marked in the returned mask.
        return sample[self.probability_map_name].data > 0.5
32