Passed
Push — master ( 9bf301...e8ad0b )
by Fernando
01:21
created

LabelSampler.get_probabilities_from_label_map()   B

Complexity

Conditions 6

Size

Total Lines 28
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 24
nop 2
dl 0
loc 28
rs 8.3706
c 0
b 0
f 0
1
from typing import Dict, Optional
2
3
import torch
4
5
from ...data.subject import Subject
6
from ...torchio import TypePatchSize, DATA, TYPE, LABEL
7
from .weighted import WeightedSampler
8
9
10
class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    By default, this sampler yields patches whose center value is greater
    than 0 in the :py:attr:`label_name`.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the sample that will be used to
            generate the sampling probability map. If ``None``, the first image
            of type :py:attr:`torchio.LABEL` found in the subject sample will be
            used.
        label_probabilities: Dictionary containing the probability that each
            class will be sampled. Probabilities do not need to be normalized.
            For example, a value of ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patch centers will have 50% probability of being
            labeled as ``1``, 25% of being ``2`` and 25% of being ``3``.
            If ``None``, the label map is binarized and the value is set to
            ``{0: 0, 1: 1}``.
            If the input has multiple channels, a value of
            ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patch centers will have 50% probability of being
            taken from a non-zero value of channel ``1``, 25% from channel
            ``2`` and 25% from channel ``3``.

    Example:
        >>> import torchio
        >>> subject = torchio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> sample = torchio.SubjectsDataset([subject])[0]
        >>> sampler = torchio.data.LabelSampler(64, 'brain')
        >>> generator = sampler(sample)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(sample, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            label_name: Optional[str] = None,
            label_probabilities: Optional[Dict[int, float]] = None,
        ):
        # The label image name is forwarded to the parent as the name of the
        # probability map; it is read back below as
        # ``self.probability_map_name``.
        super().__init__(patch_size, probability_map=label_name)
        self.label_probabilities_dict = label_probabilities

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return the sampling probability map for ``sample``.

        The label map is the image named ``self.probability_map_name`` or,
        if that is ``None``, the first image of type ``LABEL`` in the sample.

        Args:
            sample: Subject sample containing the label image.

        Returns:
            If no label probabilities were given, a boolean tensor that is
            ``True`` where the label map is greater than 0. Otherwise, the
            tensor built by :py:meth:`get_probabilities_from_label_map`.

        Raises:
            KeyError: If the named label image is not in ``sample``, or if no
                name was given and ``sample`` has no image of type ``LABEL``.
        """
        if self.probability_map_name is None:
            for image in sample.get_images(intensity_only=False):
                if image[TYPE] == LABEL:
                    label_map_tensor = image[DATA]
                    break
            else:
                # Without this branch, ``label_map_tensor`` would be unbound
                # below and fail with an opaque UnboundLocalError.
                message = (
                    'No image of type "label" found in subject sample:'
                    f' {sample}'
                )
                raise KeyError(message)
        elif self.probability_map_name in sample:
            label_map_tensor = sample[self.probability_map_name][DATA]
        else:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        if self.label_probabilities_dict is None:
            return label_map_tensor > 0
        probability_map = self.get_probabilities_from_label_map(
            label_map_tensor,
            self.label_probabilities_dict,
        )
        return probability_map

    @staticmethod
    def get_probabilities_from_label_map(
            label_map: torch.Tensor,
            label_probabilities_dict: Dict[int, float],
            ) -> torch.Tensor:
        """Create probability map according to label map probabilities.

        Args:
            label_map: Tensor whose first dimension indexes channels. With a
                single channel, voxel values are compared against each label.
                With several channels, channel ``label`` is used directly as
                the mask (weights) for that label.
            label_probabilities_dict: Mapping from label value (or channel
                index, for multichannel inputs) to an unnormalized sampling
                probability.

        Returns:
            A floating-point tensor in which each label's normalized
            probability is spread evenly over that label's voxels. It has the
            shape of ``label_map`` for single-channel input. For multichannel
            input, the per-channel maps are summed into a single channel.
        """
        multichannel = label_map.shape[0] > 1
        # Inheriting an integer dtype from the label map would truncate the
        # fractional per-voxel probabilities to zero.
        if label_map.is_floating_point():
            dtype = label_map.dtype
        else:
            dtype = torch.get_default_dtype()
        probability_map = torch.zeros_like(label_map, dtype=dtype)
        label_probs = torch.Tensor(list(label_probabilities_dict.values()))
        normalized_probs = label_probs / label_probs.sum()
        iterable = zip(label_probabilities_dict, normalized_probs)
        for label, label_probability in iterable:
            if multichannel:
                mask = label_map[label]
            else:
                mask = label_map == label
            label_size = mask.sum()
            if not label_size:
                # Absent labels contribute nothing (and would divide by 0).
                continue
            # Divide by the label size so the total mass assigned to this
            # label equals its normalized probability.
            prob_voxels = label_probability / label_size
            if multichannel:
                probability_map[label] = prob_voxels * mask
            else:
                probability_map[mask] = prob_voxels
        if multichannel:
            probability_map = probability_map.sum(dim=0, keepdim=True)
        return probability_map
112