Passed: Pull Request to master (#175) by Fernando (58s)

WeightedSampler.process_probability_map() (rating: A)

Complexity: Conditions 2
Size: Total Lines 13, Code Lines 11
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric   Value
cc       2
eloc     11
nop      2
dl       0
loc      13
rs       9.85
c        0
b        0
f        0
from typing import Optional, Tuple, Generator

import numpy as np

import torch

from ...torchio import TypePatchSize
from ..subject import Subject
from .sampler import PatchSampler


class WeightedSampler(PatchSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        sample: Sample generated by a
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
            patches will be extracted.
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        probability_map: Name of the image in the sample that will be used
            as a probability map.

    Example:
        >>> import torchio
        >>> subject = torchio.Subject(
        ...     t1=torchio.Image('t1_mri.nii.gz', type=torchio.INTENSITY),
        ...     sampling_map=torchio.Image('sampling.nii.gz', type=torchio.SAMPLING_MAP),
        ... )
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> patch_size = 64
        >>> sampler = torchio.data.WeightedSampler(patch_size, probability_map='sampling_map')
        >>> for patch in sampler(sample):
        ...     print(patch['index_ini'])

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0 as
        the center of the patch cannot be at the border (unless the patch has
        size 1 or 2 along that axis).

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            probability_map: Optional[str] = None,
            ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        self.cdf = None
        self.sort_indices = None

    def __call__(
            self,
            sample: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        sample.check_consistent_shape()
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)
        probability_map = self.get_probability_map(sample)
        probability_map = self.process_probability_map(probability_map)
        cdf, sort_indices = self.get_cumulative_distribution_function(
            probability_map)

        patches_left = num_patches if num_patches is not None else True
        while patches_left:
            yield self.extract_patch(sample, probability_map, cdf, sort_indices)
            if num_patches is not None:
                patches_left -= 1

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        if self.probability_map_name in sample:
            data = sample[self.probability_map_name].data
        else:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(
            self,
            probability_map: torch.Tensor,
            ) -> np.ndarray:
        # Using float32 can create a CDF whose maximum is very far from 1, e.g. 0.92!
        data = probability_map[0].numpy().astype(np.float64)
        assert data.ndim == 3
        if data.sum() == 0:  # although it should not be empty
            data += 1  # make uniform
        data /= data.sum()  # normalize probabilities
        self.clear_probability_borders(data, self.patch_size)
        assert data.sum() > 0
        return data

    @staticmethod
    def clear_probability_borders(
            probability_map: np.ndarray,
            patch_size: TypePatchSize,
            ) -> None:
        # Set the probability to 0 on voxels that cannot possibly be sampled
        # given the current patch size
        # We arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array of shape (3, 4) will be at (1, 2)
        #
        #   Patch         center
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #
        #    Prob. map      After preprocessing
        #
        #  x x x x x x x       . . . . . . .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x       . . . . . . .
        #
        # The dots represent removed probabilities; x marks possible locations
        crop_ini = patch_size // 2
        crop_fin = (patch_size - 1) // 2
        crop_i, crop_j, crop_k = crop_ini
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        # The call to tolist() is very important. Using np.uint16 as a negative
        # index will not work because e.g. -np.uint16(2) == 65534
        crop_i, crop_j, crop_k = crop_fin.tolist()
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(
            probability_map: np.ndarray,
            ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the CDF of a probability map.

        The cumulative distribution function (CDF) is computed as follows:

        1. Flatten the probability map
        2. Normalize it
        3. Compute the sorting indices
        4. Sort the flattened map
        5. Compute the cumulative sum

        For example,
        if the probability map is [0, 0, 1, 2, 5, 1, 1, 0],
        the normalized version is [0.0, 0.0, 0.1, 0.2, 0.5, 0.1, 0.1, 0.0],
        the sorting indices are [0, 1, 7, 2, 5, 6, 3, 4],
        the sorted map is [0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.2, 0.5],
        and the CDF is [0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.5, 1.0].
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        # Get the sorting indices so that we can invert the sorting later on
        sort_indices = np.argsort(flat_map_normalized)
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
        cdf = np.cumsum(flat_map_normalized_sorted)
        return cdf, sort_indices

    def extract_patch(
            self,
            sample: Subject,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            sort_indices: np.ndarray,
            ) -> Subject:
        index_ini = self.get_random_index_ini(probability_map, cdf, sort_indices)
        crop = self.get_crop_transform(
            sample,
            index_ini,
            self.patch_size,
        )
        cropped_sample = crop(sample)
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    def get_random_index_ini(
            self,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            sort_indices: np.ndarray,
            ) -> np.ndarray:
        center = self.sample_probability_map(probability_map, cdf, sort_indices)
        assert np.all(center >= 0)
        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    def sample_probability_map(
            self,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            sort_indices: np.ndarray,
            ) -> np.ndarray:
        """Inverse transform sampling.

        Example:
            >>> probability_map = np.array(
            ...    ((0,0,1,1,5,2,1,1,0),
            ...     (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> histogram = np.zeros_like(probability_map)
            >>> for _ in range(100000):
            ...     histogram[sample_probability_map(probability_map)] += 1
            ...
            >>> histogram
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])

        """
        # Get first value larger than random number
        random_number = torch.rand(1).item()
        # If probability map is float32, cdf.max() can be far from 1, e.g. 0.92
        if random_number > cdf.max():
            cdf_index = -1
        else:  # proceed as usual
            cdf_index = np.argmax(random_number < cdf)

        random_location_index = sort_indices[cdf_index]
        center = np.unravel_index(
            random_location_index,
            probability_map.shape
        )

        i, j, k = center
        probability = probability_map[i, j, k]
        assert probability > 0

        center = np.array(center).astype(int)
        return center
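
The CDF construction and the inverse transform sampling described in the docstrings of get_cumulative_distribution_function() and sample_probability_map() can be checked outside the class with plain NumPy. The following is a minimal standalone sketch, not part of the module above: names such as weights, flat, counts and rng are made up for the example, and it reuses the one-dimensional weights [0, 0, 1, 2, 5, 1, 1, 0] from the docstring.

import numpy as np

# Same weights as in the get_cumulative_distribution_function() docstring
weights = np.array([0, 0, 1, 2, 5, 1, 1, 0], dtype=np.float64)

# Steps 1-2: flatten and normalize
flat = weights.flatten()
flat /= flat.sum()
# Steps 3-5: sorting indices, sorted map, cumulative sum
sort_indices = np.argsort(flat)
cdf = np.cumsum(flat[sort_indices])
print(cdf)  # approx. [0.  0.  0.  0.1 0.2 0.3 0.5 1. ]

# Inverse transform sampling: draw a uniform number, take the first CDF entry
# above it, and map that position back to the original (unsorted) index
rng = np.random.default_rng(0)
counts = np.zeros_like(weights)
for _ in range(100000):
    u = rng.random()
    cdf_index = -1 if u > cdf.max() else int(np.argmax(u < cdf))
    counts[sort_indices[cdf_index]] += 1

print(counts / counts.sum())  # roughly [0, 0, 0.1, 0.2, 0.5, 0.1, 0.1, 0]

Each index is drawn with probability proportional to its weight, so the empirical frequencies approach the normalized map; this is also why, as the class docstring states, a voxel with value 5 is picked about five times as often as a voxel with value 1.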
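
The border clearing performed by clear_probability_borders() can be illustrated in the same way. The sketch below is again standalone and hypothetical: it applies the same crop_ini / crop_fin arithmetic as the method to a map of shape (6, 7, 1) with a patch size of (3, 4, 1), chosen so that the printed result matches the ASCII diagram in the comments above.

import numpy as np

patch_size = np.array([3, 4, 1])
prob = np.ones((6, 7, 1))  # uniform probability map, third axis of size 1

crop_ini = patch_size // 2        # (1, 2, 0) voxels zeroed at the start of each axis
crop_fin = (patch_size - 1) // 2  # (1, 1, 0) voxels zeroed at the end of each axis
i0, j0, k0 = crop_ini
prob[:i0, :, :] = 0
prob[:, :j0, :] = 0
prob[:, :, :k0] = 0
i1, j1, k1 = crop_fin.tolist()  # tolist() avoids negative unsigned indices
if i1:
    prob[-i1:, :, :] = 0
if j1:
    prob[:, -j1:, :] = 0
if k1:
    prob[:, :, -k1:] = 0

print(prob[:, :, 0].astype(int))  # ones only in rows 1-4 and columns 2-5

Only the voxels whose surrounding (3, 4, 1) patch fits entirely inside the volume keep a nonzero probability, which is exactly the behavior described in the second note of the class docstring.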
263