Completed
Push — master ( d4206f...57ce9a )
by Fernando
01:43
created

WeightedSampler.get_probability_map_image()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 2
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
from typing import Optional, Tuple, Generator
2
3
import torch
4
import numpy as np
5
6
from ...typing import TypePatchSize
7
from ..image import Image
8
from ..subject import Subject
9
from .sampler import RandomSampler
10
11
12
class WeightedSampler(RandomSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
        probability_map: Name of the image in the input subject that will be
            used as a sampling probability map.

    Raises:
        RuntimeError: If the probability map is empty.

    Example:
        >>> import torchio as tio
        >>> subject = tio.Subject(
        ...     t1=tio.ScalarImage('t1_mri.nii.gz'),
        ...     sampling_map=tio.Image('sampling.nii.gz', type=tio.SAMPLING_MAP),
        ... )
        >>> patch_size = 64
        >>> sampler = tio.data.WeightedSampler(patch_size, 'sampling_map')
        >>> for patch in sampler(subject):
        ...     print(patch['index_ini'])

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0 as
        the center of the patch cannot be at the border (unless the patch has
        size 1 or 2 along that axis).

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            probability_map: str,
            ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        # Not read by any method defined in this class; kept so that external
        # code accessing the attribute keeps working.
        self.cdf = None

    def __call__(
            self,
            subject: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches sampled from ``subject`` according to the map.

        As this is a generator, the checks below run when the first patch is
        requested, not when the sampler is called.

        Args:
            subject: Subject to extract the patches from. It must contain the
                image named by the ``probability_map`` constructor argument.
            num_patches: Number of patches to yield. If ``None``, patches are
                yielded indefinitely. Values smaller than 1 yield no patches.

        Raises:
            RuntimeError: If the patch size is larger than the spatial shape
                of the subject along any axis, or if the probability map is
                empty.
            KeyError: If the probability map image is not in the subject.
            ValueError: If the probability map contains negative values.
        """
        subject.check_consistent_space()
        if np.any(self.patch_size > subject.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(subject.spatial_shape)}'
            )
            raise RuntimeError(message)
        probability_map = self.get_probability_map(subject)
        probability_map = self.process_probability_map(probability_map, subject)
        cdf = self.get_cumulative_distribution_function(probability_map)

        # Count upwards instead of decrementing a counter until it is falsy:
        # a negative num_patches would never reach 0 and loop forever.
        num_yielded = 0
        while num_patches is None or num_yielded < num_patches:
            yield self.extract_patch(subject, probability_map, cdf)
            num_yielded += 1

    def get_probability_map_image(self, subject: Subject) -> Image:
        """Return the image of ``subject`` used as probability map.

        Raises:
            KeyError: If the subject has no image with the configured name.
        """
        if self.probability_map_name in subject:
            return subject[self.probability_map_name]
        message = (
            f'Image "{self.probability_map_name}"'
            f' not found in subject: {subject}'
        )
        raise KeyError(message)

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return the data of the probability map image of ``subject``.

        Raises:
            KeyError: If the subject has no image with the configured name.
            ValueError: If any value in the map is negative.
        """
        data = self.get_probability_map_image(subject).data
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(
            self,
            probability_map: torch.Tensor,
            subject: Subject,
            ) -> np.ndarray:
        """Return a normalized 3D map with the unreachable borders zeroed.

        Args:
            probability_map: Tensor whose first element along the first axis
                is the 3D map (presumably the channel axis -- TODO confirm).
            subject: Subject the map belongs to. Only used to report the path
                of the map image in the error message.

        Returns:
            A new ``float64`` array that sums to 1. The input tensor is not
            modified, as ``astype`` makes a copy.

        Raises:
            RuntimeError: If the map sums to 0 after clearing the borders.
        """
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
        data = probability_map[0].numpy().astype(np.float64)
        assert data.ndim == 3
        self.clear_probability_borders(data, self.patch_size)
        total = data.sum()
        if total == 0:
            message = (
                'Empty probability map found:'
                f' {self.get_probability_map_image(subject).path}'
            )
            raise RuntimeError(message)
        data /= total  # normalize probabilities
        return data

    @staticmethod
    def clear_probability_borders(
            probability_map: np.ndarray,
            patch_size: TypePatchSize,
            ) -> None:
        """Zero, in place, the voxels that cannot be the center of a patch.

        Args:
            probability_map: 3D array, modified in place.
            patch_size: Sequence or array with the three patch dimensions.
        """
        # Set probability to 0 on voxels that wouldn't possibly be sampled given
        # the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #   Patch         center
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #
        #    Prob. map      After preprocessing
        #
        #  x x x x x x x       . . . . . . .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x       . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations

        # Accept tuples and lists as well as arrays
        patch_size = np.asarray(patch_size)

        # The call tolist() is very important. Using np.uint16 as negative index
        # will not work because e.g. -np.uint16(2) == 65534
        crop_ini = (patch_size // 2).tolist()
        crop_fin = ((patch_size - 1) // 2).tolist()

        # A slice [:0] is empty, so no check is needed on the initial side
        crop_i, crop_j, crop_k = crop_ini
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        # A slice [-0:] is the whole axis, so zero crops must be skipped
        crop_i, crop_j, crop_k = crop_fin
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(
            probability_map: np.ndarray,
            ) -> np.ndarray:
        """Return the cumulative distribution function of a probability map.

        The map is flattened and normalized by its sum before accumulating,
        so the result is a 1D array with one value per voxel.
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        cdf = np.cumsum(flat_map_normalized)
        return cdf

    def extract_patch(
            self,
            subject: Subject,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> Subject:
        """Sample a location and return the corresponding cropped subject.

        The initial index of the patch is stored in the returned subject
        under the key ``'index_ini'``.
        """
        index_ini = self.get_random_index_ini(probability_map, cdf)
        # crop() is not defined in this class; presumably inherited from
        # RandomSampler
        cropped_subject = self.crop(subject, index_ini, self.patch_size)
        cropped_subject['index_ini'] = index_ini.astype(int)
        return cropped_subject

    def get_random_index_ini(
            self,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> np.ndarray:
        """Sample a patch center and return the initial index of the patch."""
        center = self.sample_probability_map(probability_map, cdf)
        assert np.all(center >= 0)
        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    @classmethod
    def sample_probability_map(
            cls,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> np.ndarray:
        """Inverse transform sampling.

        Args:
            probability_map: Array of non-negative values.
            cdf: Cumulative distribution function of the flattened map, as
                returned by :meth:`get_cumulative_distribution_function`.

        Returns:
            Integer array with the index of the sampled voxel, one value per
            dimension of ``probability_map``.

        Example:
            >>> probability_map = np.array(
            ...    ((0,0,1,1,5,2,1,1,0),
            ...     (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> cdf = WeightedSampler.get_cumulative_distribution_function(
            ...     probability_map)
            >>> histogram = np.zeros_like(probability_map)
            >>> for _ in range(100000):  # doctest:+SKIP
            ...     center = WeightedSampler.sample_probability_map(
            ...         probability_map, cdf)
            ...     histogram[tuple(center)] += 1
            ...
            >>> histogram  # doctest:+SKIP
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])

        """
        random_number = torch.rand(1).item()
        # Get first value strictly larger than the random number. Using
        # side='right' guarantees that the chosen voxel has a non-zero
        # probability: a random number equal to 0 or to a plateau of the cdf
        # would otherwise select a voxel whose probability is 0.
        cdf_index = np.searchsorted(cdf, random_number, side='right')
        if cdf_index >= len(cdf):
            # Because of rounding, cdf[-1] can be slightly smaller than 1 and
            # the random number can be larger. Use the first index at which
            # the cdf reaches its final value, i.e. the last voxel with
            # non-zero probability. A negative index cannot be used here as
            # np.unravel_index does not accept it.
            cdf_index = np.searchsorted(cdf, cdf[-1], side='left')

        center = np.unravel_index(cdf_index, probability_map.shape)

        # Index with the tuple so maps of any dimensionality are supported
        probability = probability_map[center]
        assert probability > 0

        center = np.array(center).astype(int)
        return center
241