Passed
Push — master ( fca4b7...25d6cf )
by Fernando
01:33
created

torchio.data.sampler.label   A

Complexity

Total Complexity 10

Size/Duplication

Total Lines 96
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 47
dl 0
loc 96
rs 10
c 0
b 0
f 0
wmc 10

3 Methods

Rating   Name   Duplication   Size   Complexity  
B LabelSampler.get_probability_map() 0 21 6
A LabelSampler.get_probabilities_from_label_map() 0 17 3
A LabelSampler.__init__() 0 8 1
1
from typing import Dict, Optional
2
3
import torch
4
5
from ...data.subject import Subject
6
from ...torchio import TypePatchSize, DATA, TYPE, LABEL
7
from .weighted import WeightedSampler
8
9
10
class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    This sampler yields patches whose center value is greater than 0
    in the :py:attr:`label_name`.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the sample that will be used to
            generate the sampling probability map. If ``None``, the first image
            of type :py:attr:`torchio.LABEL` found in the subject sample will be
            used.
        label_probabilities: Dictionary containing the probability that each
            class will be sampled. Probabilities do not need to be normalized.
            For example, a value of ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patch centers will have 50% probability of being
            labeled as ``1``, 25% of being ``2`` and 25% of being ``3``.
            If ``None``, the label map is binarized and the value is set to
            ``{0: 0, 1: 1}``.

    Example:
        >>> import torchio
        >>> subject = torchio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sampler = torchio.data.LabelSampler(64, 'brain')
        >>> generator = sampler(sample)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(sample, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            label_name: Optional[str] = None,
            label_probabilities: Optional[Dict[int, float]] = None,
            ):
        # The label image name is handed to WeightedSampler as its
        # ``probability_map`` argument; it is read back below as
        # ``self.probability_map_name`` (presumably set by WeightedSampler —
        # confirm against the parent class).
        super().__init__(patch_size, probability_map=label_name)
        self.label_probabilities_dict = label_probabilities

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return the sampling probability map for ``sample``.

        If no ``label_probabilities`` were given, the label map is binarized
        (``label_map > 0``). Otherwise, per-label probabilities are computed
        with :py:meth:`get_probabilities_from_label_map`.

        Raises:
            KeyError: If the requested label image is not in ``sample``, or if
                no label name was given and ``sample`` contains no image of
                type :py:attr:`torchio.LABEL`.
        """
        label_map_tensor = self._get_label_map_tensor(sample)
        if self.label_probabilities_dict is None:
            return label_map_tensor > 0
        probability_map = self.get_probabilities_from_label_map(
            label_map_tensor,
            self.label_probabilities_dict,
        )
        return probability_map

    def _get_label_map_tensor(self, sample: Subject) -> torch.Tensor:
        """Return the data tensor of the label image used for sampling."""
        if self.probability_map_name is None:
            # Use the first image of type LABEL in the sample.
            for image in sample.get_images(intensity_only=False):
                if image[TYPE] == LABEL:
                    return image[DATA]
            # Previously this fell through and crashed later with an
            # UnboundLocalError; report the real problem instead.
            message = f'No label image found in subject sample: {sample}'
            raise KeyError(message)
        if self.probability_map_name in sample:
            return sample[self.probability_map_name][DATA]
        message = (
            f'Image "{self.probability_map_name}"'
            f' not found in subject sample: {sample}'
        )
        raise KeyError(message)

    @staticmethod
    def get_probabilities_from_label_map(
            label_map: torch.Tensor,
            label_probabilities_dict: Dict[int, float],
            ) -> torch.Tensor:
        """Create probability map according to label map probabilities.

        The probabilities in ``label_probabilities_dict`` are normalized by
        their sum. Each label's normalized probability is then spread
        uniformly over the voxels carrying that label. Labels absent from
        ``label_map`` are skipped, so the resulting map may sum to less
        than 1. Voxels whose label is not in the dictionary get 0.

        Args:
            label_map: Tensor of label values.
            label_probabilities_dict: Mapping from label value to its
                (unnormalized) sampling probability.

        Returns:
            A tensor with the same shape as ``label_map``. Its dtype is that
            of ``label_map`` if it is floating point, otherwise the default
            float dtype.
        """
        # Integer/bool label maps would truncate the fractional per-voxel
        # probabilities to 0, so only reuse the input dtype if it is a float.
        if label_map.is_floating_point():
            dtype = label_map.dtype
        else:
            dtype = torch.get_default_dtype()
        probability_map = torch.zeros_like(label_map, dtype=dtype)
        label_probs = torch.Tensor(list(label_probabilities_dict.values()))
        normalized_probs = label_probs / label_probs.sum()
        iterable = zip(label_probabilities_dict, normalized_probs)
        for label, label_probability in iterable:
            mask = label_map == label
            label_size = mask.sum()
            if not label_size:
                # Label not present: nothing to distribute its mass over.
                continue
            prob_voxels = label_probability / label_size
            probability_map[mask] = prob_voxels
        return probability_map
96