Passed
Push — master ( c9c9a5...7d9f03 )
by Fernando
01:40
created

WeightedSampler.__call__()   A

Complexity

Conditions 5

Size

Total Lines 22
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 18
nop 3
dl 0
loc 22
rs 9.0333
c 0
b 0
f 0
1
from typing import Optional, Tuple, Generator
2
3
import torch
4
import numpy as np
5
6
from ...constants import MIN_FLOAT_32
7
from ...typing import TypePatchSize
8
from ..image import Image
9
from ..subject import Subject
10
from .sampler import RandomSampler
11
12
13
class WeightedSampler(RandomSampler):
14
    r"""Randomly extract patches from a volume given a probability map.
15
16
    The probability of sampling a patch centered on a specific voxel is the
17
    value of that voxel in the probability map. The probabilities need not be
18
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
19
    value 0 will never be at the center of a patch. Voxels with value 5 will
20
    have 5 times more chance of being at the center of a patch than voxels
21
    with a value of 1.
22
23
    Args:
24
        patch_size: See :class:`~torchio.data.PatchSampler`.
25
        probability_map: Name of the image in the input subject that will be
26
            used as a sampling probability map.
27
28
    Raises:
29
        RuntimeError: If the probability map is empty.
30
31
    Example:
32
        >>> import torchio as tio
33
        >>> subject = tio.Subject(
34
        ...     t1=tio.ScalarImage('t1_mri.nii.gz'),
35
        ...     sampling_map=tio.Image('sampling.nii.gz', type=tio.SAMPLING_MAP),
36
        ... )
37
        >>> patch_size = 64
38
        >>> sampler = tio.data.WeightedSampler(patch_size, 'sampling_map')
39
        >>> for patch in sampler(subject):
40
        ...     print(patch['index_ini'])
41
42
    .. note:: The index of the center of a patch with even size :math:`s` is
43
        arbitrarily set to :math:`s/2`. This is an implementation detail that
44
        will typically not make any difference in practice.
45
46
    .. note:: Values of the probability map near the border will be set to 0 as
47
        the center of the patch cannot be at the border (unless the patch has
48
        size 1 or 2 along that axis).
49
50
    """  # noqa: E501
51
    def __init__(self, patch_size: TypePatchSize, probability_map: str):
        """Set up the sampler.

        Args:
            patch_size: Forwarded unchanged to the parent class initializer.
            probability_map: Name of the image in the subject that will be
                used as a sampling probability map.
        """
        super().__init__(patch_size)
        self.cdf = None
        self.probability_map_name = probability_map
59
60
    def _generate_patches(
61
            self,
62
            subject: Subject,
63
            num_patches: Optional[int] = None,
64
            ) -> Generator[Subject, None, None]:
65
        probability_map = self.get_probability_map(subject)
66
        probability_map = self.process_probability_map(
67
            probability_map, subject)
68
        cdf = self.get_cumulative_distribution_function(probability_map)
69
70
        patches_left = num_patches if num_patches is not None else True
71
        while patches_left:
72
            yield self.extract_patch(subject, probability_map, cdf)
73
            if num_patches is not None:
74
                patches_left -= 1
75
76
    def get_probability_map_image(self, subject: Subject) -> Image:
77
        if self.probability_map_name in subject:
78
            return subject[self.probability_map_name]
79
        else:
80
            message = (
81
                f'Image "{self.probability_map_name}"'
82
                f' not found in subject: {subject}'
83
            )
84
            raise KeyError(message)
85
86
    def get_probability_map(self, subject: Subject) -> torch.Tensor:
87
        data = self.get_probability_map_image(subject).data
88
        if torch.any(data < 0):
89
            message = (
90
                'Negative values found'
91
                f' in probability map "{self.probability_map_name}"'
92
            )
93
            raise ValueError(message)
94
        return data
95
96
    def process_probability_map(
97
            self,
98
            probability_map: torch.Tensor,
99
            subject: Subject,
100
            ) -> np.ndarray:
101
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
102
        data = probability_map[0].numpy().astype(np.float64)
103
        assert data.ndim == 3
104
        self.clear_probability_borders(data, self.patch_size)
105
        total = data.sum()
106
        if total == 0:
107
            half_patch_size = tuple(n // 2 for n in self.patch_size)
108
            message = (
109
                'Empty probability map found:'
110
                f' {self.get_probability_map_image(subject).path}'
111
                '\nVoxels with positive probability might be near the image'
112
                ' border.\nIf you suspect that this is the case, try adding a'
113
                ' padding transform\nwith half the patch size:'
114
                f' torchio.Pad({half_patch_size})'
115
            )
116
            raise RuntimeError(message)
117
        data /= total  # normalize probabilities
118
        return data
119
120
    @staticmethod
121
    def clear_probability_borders(
122
            probability_map: np.ndarray,
123
            patch_size: TypePatchSize,
124
            ) -> None:
125
        # Set probability to 0 on voxels that wouldn't possibly be sampled
126
        # given the current patch size
127
        # We will arbitrarily define the center of an array with even length
128
        # using the // Python operator
129
        # For example, the center of an array (3, 4) will be on (1, 2)
130
        #
131
        #   Patch         center
132
        #  . . . .        . . . .
133
        #  . . . .   ->   . . x .
134
        #  . . . .        . . . .
135
        #
136
        #
137
        #    Prob. map      After preprocessing
138
        #
139
        #  x x x x x x x       . . . . . . .
140
        #  x x x x x x x       . . x x x x .
141
        #  x x x x x x x  -->  . . x x x x .
142
        #  x x x x x x x  -->  . . x x x x .
143
        #  x x x x x x x       . . x x x x .
144
        #  x x x x x x x       . . . . . . .
145
        #
146
        # The dots represent removed probabilities, x mark possible locations
147
        crop_ini = patch_size // 2
148
        crop_fin = (patch_size - 1) // 2
149
        crop_i, crop_j, crop_k = crop_ini
150
        probability_map[:crop_i, :, :] = 0
151
        probability_map[:, :crop_j, :] = 0
152
        probability_map[:, :, :crop_k] = 0
153
154
        # The call tolist() is very important. Using np.uint16 as negative
155
        # index will not work because e.g. -np.uint16(2) == 65534
156
        crop_i, crop_j, crop_k = crop_fin.tolist()
157
        if crop_i:
158
            probability_map[-crop_i:, :, :] = 0
159
        if crop_j:
160
            probability_map[:, -crop_j:, :] = 0
161
        if crop_k:
162
            probability_map[:, :, -crop_k:] = 0
163
164
    @staticmethod
165
    def get_cumulative_distribution_function(
166
            probability_map: np.ndarray,
167
            ) -> Tuple[np.ndarray, np.ndarray]:
168
        """Return the cumulative distribution function of a probability map."""
169
        flat_map = probability_map.flatten()
170
        flat_map_normalized = flat_map / flat_map.sum()
171
        cdf = np.cumsum(flat_map_normalized)
172
        return cdf
173
174
    def extract_patch(
            self,
            subject: Subject,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> Subject:
        """Crop one randomly located patch from the subject.

        Args:
            subject: Subject to crop.
            probability_map: Probability map used to sample the location.
            cdf: Cumulative distribution function of ``probability_map``.

        Returns:
            The cropped subject, with the initial index of the patch stored
            under the key ``'index_ini'``.
        """
        start = self.get_random_index_ini(probability_map, cdf)
        patch = self.crop(subject, start, self.patch_size)
        patch['index_ini'] = start.astype(int)
        return patch
184
185
    def get_random_index_ini(
186
            self,
187
            probability_map: np.ndarray,
188
            cdf: np.ndarray
189
            ) -> np.ndarray:
190
        center = self.sample_probability_map(probability_map, cdf)
191
        assert np.all(center >= 0)
192
        # See self.clear_probability_borders
193
        index_ini = center - self.patch_size // 2
194
        assert np.all(index_ini >= 0)
195
        return index_ini
196
197
    @classmethod
    def sample_probability_map(
            cls,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> np.ndarray:
        """Sample the coordinates of a voxel using inverse transform sampling.

        Args:
            probability_map: Array of non-negative weights.
            cdf: Cumulative distribution function of ``probability_map``, as
                returned by :meth:`get_cumulative_distribution_function`. It
                is expected to be non-decreasing.

        Returns:
            1D array with the index of the sampled voxel along each dimension
            of ``probability_map``.

        Raises:
            RuntimeError: If the probability of the sampled voxel is not
                positive.

        Example:
            >>> probability_map = np.array(
            ...    ((0,0,1,1,5,2,1,1,0),
            ...     (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> cdf = WeightedSampler.get_cumulative_distribution_function(probability_map)
            >>> histogram = np.zeros_like(probability_map)
            >>> for _ in range(100000):  # doctest:+SKIP
            ...     center = WeightedSampler.sample_probability_map(probability_map, cdf)
            ...     histogram[tuple(center)] += 1
            ...
            >>> histogram  # doctest:+SKIP
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])

        """  # noqa: E501
        # Get first value larger than random number ensuring the random number
        # is not exactly 0 (see https://github.com/fepegar/torchio/issues/510)
        random_number = max(MIN_FLOAT_32, torch.rand(1).item())

        # Accumulated floating point errors might make the maximum of the cdf
        # (its last element, as it is non-decreasing) less than 1. Clamping
        # the random number makes the search land on the first index at which
        # the cdf reaches its maximum, i.e. on the last voxel with positive
        # probability. Falling back to the very last voxel instead would pick
        # a voxel whose probability is 0 whenever the borders were cleared
        random_number = min(random_number, cdf[-1])
        random_location_index = np.searchsorted(cdf, random_number)

        center = np.unravel_index(
            random_location_index,
            probability_map.shape
        )

        probability = probability_map[center]
        if probability <= 0:
            message = (
                'Error retrieving probability in weighted sampler.'
                ' Please report this issue at'
                ' https://github.com/fepegar/torchio/issues/new?labels=bug&template=bug_report.md'  # noqa: E501
            )
            raise RuntimeError(message)

        return np.array(center)
246