Passed
Pull Request — master (#386)
by Fernando
01:21
created

LabelSampler.get_probability_map_image()   A

Complexity

Conditions 5

Size

Total Lines 15
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 12
nop 2
dl 0
loc 15
rs 9.3333
c 0
b 0
f 0
1
from typing import Dict, Optional
2
3
import torch
4
5
from ...data.image import LabelMap
6
from ...data.subject import Subject
7
from ...typing import TypePatchSize
8
from ...constants import TYPE, LABEL
9
from .weighted import WeightedSampler
10
11
12
class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    This sampler yields patches whose center value is greater than 0
    in the :attr:`label_name`.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the subject that will be used to
            generate the sampling probability map. If ``None``, the first image
            of type :attr:`torchio.LABEL` found in the subject will be
            used.
        label_probabilities: Dictionary containing the probability that each
            class will be sampled. Probabilities do not need to be normalized.
            For example, a value of ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patches centers will have 50% probability of being
            labeled as ``1``, 25% of being ``2`` and 25% of being ``3``.
            If ``None``, the label map is binarized and the value is set to
            ``{0: 0, 1: 1}``.
            If the input has multiple channels, a value of
            ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patches centers will have 50% probability of being
            taken from a non zero value of channel ``1``, 25% from channel
            ``2`` and 25% from channel ``3``.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> subject = tio.SubjectsDataset([subject])[0]
        >>> sampler = tio.data.LabelSampler(64, 'brain')
        >>> generator = sampler(subject)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(subject, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            label_name: Optional[str] = None,
            label_probabilities: Optional[Dict[int, float]] = None,
            ):
        # The label name is forwarded to WeightedSampler as the probability
        # map name; it is presumably stored there as ``probability_map_name``,
        # which get_probability_map_image reads (confirm in WeightedSampler).
        super().__init__(patch_size, probability_map=label_name)
        self.label_probabilities_dict = label_probabilities

    def get_probability_map_image(self, subject: Subject) -> LabelMap:
        """Return the label image used to build the sampling probability map.

        If no label name was given, the first image in ``subject`` whose
        ``TYPE`` is ``LABEL`` is used; otherwise the image stored under that
        name is returned.

        Args:
            subject: Subject to take the label image from.

        Returns:
            The selected label image.

        Raises:
            KeyError: If a label name was given but is not a key of
                ``subject``, or if no name was given and ``subject`` contains
                no image of type ``LABEL``.
        """
        if self.probability_map_name is None:
            for image in subject.get_images(intensity_only=False):
                if image[TYPE] == LABEL:
                    return image
            # Without this, the lookup would fall through with no label map
            # bound and fail with an unhelpful UnboundLocalError.
            message = (
                f'No image of type "{LABEL}" found in subject: {subject}'
            )
            raise KeyError(message)
        if self.probability_map_name not in subject:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject: {subject}'
            )
            raise KeyError(message)
        return subject[self.probability_map_name]

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Compute the per-voxel sampling map for ``subject``.

        Args:
            subject: Subject containing the label image.

        Returns:
            If no label probabilities were given, a boolean tensor that is
            ``True`` wherever the label value is positive. Otherwise, the
            float tensor built by :meth:`get_probabilities_from_label_map`.
        """
        label_map_tensor = self.get_probability_map_image(subject).data.float()
        if self.label_probabilities_dict is None:
            # Binarize: any voxel with a positive label is a valid center.
            return label_map_tensor > 0
        return self.get_probabilities_from_label_map(
            label_map_tensor,
            self.label_probabilities_dict,
        )

    @staticmethod
    def get_probabilities_from_label_map(
            label_map: torch.Tensor,
            label_probabilities_dict: Dict[int, float],
            ) -> torch.Tensor:
        """Create probability map according to label map probabilities.

        The given probabilities are normalized to sum to one. Each label's
        share is then spread over the voxels of that label, so that the
        whole label region receives its share of the total probability.

        Args:
            label_map: Label tensor; its first dimension is treated as the
                channel dimension. With more than one channel, each dict key
                indexes a channel. With a single channel, each key is
                compared against the voxel values.
            label_probabilities_dict: Unnormalized probability per label
                (single channel) or per channel index (multichannel).

        Returns:
            Probability tensor. It has the same shape as ``label_map`` for
            single-channel input and is summed to one channel (kept as a
            dimension) for multichannel input.
        """
        multichannel = label_map.shape[0] > 1
        probability_map = torch.zeros_like(label_map)
        label_probs = torch.Tensor(list(label_probabilities_dict.values()))
        normalized_probs = label_probs / label_probs.sum()
        for label, label_probability in zip(label_probabilities_dict, normalized_probs):
            if multichannel:
                mask = label_map[label]
            else:
                mask = label_map == label
            label_size = mask.sum()
            # Labels absent from this map get no probability mass; this also
            # avoids dividing by zero below.
            if not label_size:
                continue
            prob_voxels = label_probability / label_size
            if multichannel:
                probability_map[label] = prob_voxels * mask
            else:
                probability_map[mask] = prob_voxels
        if multichannel:
            # Collapse the per-channel maps into a single-channel map.
            probability_map = probability_map.sum(dim=0, keepdim=True)
        return probability_map
119