Completed
Push — master ( 7f8818...d756b2 )
by Fernando
01:31
created

LabelSampler.get_probability_map_image()   B

Complexity

Conditions 6

Size

Total Lines 22
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 16
nop 2
dl 0
loc 22
rs 8.6666
c 0
b 0
f 0
1
from typing import Dict, Optional
2
3
import torch
4
5
from ...data.image import LabelMap
6
from ...data.subject import Subject
7
from ...typing import TypePatchSize
8
from ...constants import TYPE, LABEL
9
from .weighted import WeightedSampler
10
11
12
class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    This sampler yields patches whose center value is greater than 0
    in the :attr:`label_name`.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the subject that will be used to
            generate the sampling probability map. If ``None``, the first image
            of type :attr:`torchio.LABEL` found in the subject will be used.
        label_probabilities: Dictionary containing the probability that each
            class will be sampled. Probabilities do not need to be normalized,
            but they must sum to a positive value.
            For example, a value of ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patch centers will have 50% probability of being
            labeled as ``1``, 25% of being ``2`` and 25% of being ``3``.
            If ``None``, the label map is binarized and the value is set to
            ``{0: 0, 1: 1}``.
            If the input has multiple channels, a value of
            ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patch centers will have 50% probability of being
            taken from a nonzero value of channel ``1``, 25% from channel
            ``2`` and 25% from channel ``3``. In this case each voxel is
            weighted by its value in the channel, so channels are expected
            to be binary (e.g. one-hot encoded).

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> subject = tio.SubjectsDataset([subject])[0]
        >>> sampler = tio.data.LabelSampler(64, 'brain')
        >>> generator = sampler(subject)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(subject, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            label_name: Optional[str] = None,
            label_probabilities: Optional[Dict[int, float]] = None,
            ):
        # The label name is forwarded as the name of the probability map
        # image; it is read back as ``self.probability_map_name``, which is
        # presumably set by WeightedSampler (not visible here).
        super().__init__(patch_size, probability_map=label_name)
        self.label_probabilities_dict = label_probabilities

    def get_probability_map_image(self, subject: Subject) -> LabelMap:
        """Return the label map used to build the sampling probability map.

        Args:
            subject: Subject containing the label map.

        Returns:
            The image named ``self.probability_map_name`` or, if that name is
            ``None``, the first image in the subject whose type is
            :attr:`torchio.LABEL`.

        Raises:
            RuntimeError: If no name was given and the subject contains no
                label map.
            KeyError: If the requested image is not in the subject.
        """
        name = self.probability_map_name
        if name is None:
            # Materialize once: the images are scanned here and, on failure,
            # listed again in the error message.
            images = list(subject.get_images(intensity_only=False))
            for image in images:
                if image[TYPE] == LABEL:
                    return image
            message = (
                f'No label maps found in subject {subject} with image paths'
                f' {[image.path for image in images]}'
            )
            raise RuntimeError(message)
        if name not in subject:
            message = (
                f'Image "{name}"'
                f' not found in subject: {subject}'
            )
            raise KeyError(message)
        return subject[name]

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Compute the sampling probability map for ``subject``.

        Args:
            subject: Subject containing the label map.

        Returns:
            If no label probabilities were given, a boolean tensor that is
            ``True`` wherever the label map is greater than 0. Otherwise, the
            tensor computed by :meth:`get_probabilities_from_label_map`.
        """
        label_map_tensor = self.get_probability_map_image(subject).data.float()
        if self.label_probabilities_dict is None:
            return label_map_tensor > 0
        return self.get_probabilities_from_label_map(
            label_map_tensor,
            self.label_probabilities_dict,
        )

    @staticmethod
    def get_probabilities_from_label_map(
            label_map: torch.Tensor,
            label_probabilities_dict: Dict[int, float],
            ) -> torch.Tensor:
        """Create probability map according to label map probabilities.

        The probabilities are normalized to sum to 1, and each label's share
        is spread uniformly over the voxels it covers. Labels that cover no
        voxels are skipped.

        Args:
            label_map: Label map tensor whose first dimension is the channel
                dimension. With a single channel, voxels are compared against
                each label value. With several channels, each label is used as
                a channel index and that channel acts as a weight mask.
            label_probabilities_dict: Unnormalized probability per label.

        Returns:
            A floating-point tensor. It has the same shape as ``label_map``
            for single-channel input; multichannel input is summed into a
            single channel.

        Raises:
            ValueError: If the probabilities do not sum to a positive value.
        """
        label_probs = torch.Tensor(list(label_probabilities_dict.values()))
        total_probability = label_probs.sum()
        # ``not > 0`` also rejects NaN; without this check an all-zero or
        # empty dictionary would silently yield a NaN or all-zero map.
        if not total_probability > 0:
            message = (
                'Label probabilities must sum to a positive value, but got'
                f' {label_probabilities_dict}'
            )
            raise ValueError(message)
        normalized_probs = label_probs / total_probability

        multichannel = label_map.shape[0] > 1
        # An integer/boolean label map would otherwise give an integer map,
        # truncating every (sub-unit) probability to 0 on assignment.
        if label_map.is_floating_point():
            dtype = label_map.dtype
        else:
            dtype = torch.get_default_dtype()
        probability_map = torch.zeros_like(label_map, dtype=dtype)

        iterable = zip(label_probabilities_dict, normalized_probs)
        for label, label_probability in iterable:
            if multichannel:
                mask = label_map[label]
            else:
                mask = label_map == label
            label_size = mask.sum()
            if not label_size:
                continue
            prob_voxels = label_probability / label_size
            if multichannel:
                probability_map[label] = prob_voxels * mask
            else:
                probability_map[mask] = prob_voxels
        if multichannel:
            probability_map = probability_map.sum(dim=0, keepdim=True)
        return probability_map