Passed
Pull Request — master (#214)
created by Fernando (01:47)

torchio.data.sampler.weighted   A

Complexity

Total Complexity 20

Size/Duplication

Total Lines 264
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 124
dl 0
loc 264
rs 10
c 0
b 0
f 0
wmc 20

9 Methods

Rating   Name   Duplication   Size   Complexity  
A WeightedSampler.__init__() 0 9 1
A WeightedSampler.process_probability_map() 0 17 2
A WeightedSampler.__call__() 0 22 5
A WeightedSampler.get_cumulative_distribution_function() 0 26 1
A WeightedSampler.get_probability_map() 0 16 3
A WeightedSampler.clear_probability_borders() 0 43 4
A WeightedSampler.sample_probability_map() 0 44 2
A WeightedSampler.get_random_index_ini() 0 12 1
A WeightedSampler.extract_patch() 0 12 1
from typing import Optional, Tuple, Generator

import numpy as np

import torch

from ...torchio import TypePatchSize
from ..subject import Subject
from .sampler import RandomSampler


class WeightedSampler(RandomSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        sample: Sample generated by a
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
            patches will be extracted.
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        probability_map: Name of the image in the sample that will be used
            as a probability map.

    Raises:
        RuntimeError: If the probability map is empty.

    Example:
        >>> import torchio
        >>> subject = torchio.Subject(
        ...     t1=torchio.Image('t1_mri.nii.gz', type=torchio.INTENSITY),
        ...     sampling_map=torchio.Image('sampling.nii.gz', type=torchio.SAMPLING_MAP),
        ... )
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> patch_size = 64
        >>> sampler = torchio.data.WeightedSampler(patch_size, probability_map='sampling_map')
        >>> for patch in sampler(sample):
        ...     print(patch['index_ini'])

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0 as
        the center of the patch cannot be at the border (unless the patch has
        size 1 or 2 along that axis).

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            probability_map: Optional[str] = None,
            ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        self.cdf = None
        self.sort_indices = None

    def __call__(
            self,
            sample: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        sample.check_consistent_shape()
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)
        probability_map = self.get_probability_map(sample)
        probability_map = self.process_probability_map(probability_map)
        cdf, sort_indices = self.get_cumulative_distribution_function(
            probability_map)

        patches_left = num_patches if num_patches is not None else True
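        # If num_patches is None, patches_left stays True and the loop below
        # never ends, so the generator yields patches indefinitely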
        while patches_left:
            yield self.extract_patch(sample, probability_map, cdf, sort_indices)
            if num_patches is not None:
                patches_left -= 1

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        if self.probability_map_name in sample:
            data = sample[self.probability_map_name].data
        else:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(
            self,
            probability_map: torch.Tensor,
            ) -> np.ndarray:
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
        data = probability_map[0].numpy().astype(np.float64)
        assert data.ndim == 3
        self.clear_probability_borders(data, self.patch_size)
        total = data.sum()
        if total == 0:
            message = (
                'Empty probability map found'
                f' ({self.probability_map_name})'
            )
            raise RuntimeError(message)
        data /= total  # normalize probabilities
        return data

    @staticmethod
    def clear_probability_borders(
            probability_map: np.ndarray,
            patch_size: TypePatchSize,
            ) -> None:
        # Set the probability to 0 on voxels that could not possibly be sampled
        # given the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array of shape (3, 4) will be at (1, 2)
        #
        #   Patch         center
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #
        #    Prob. map      After preprocessing
        #
        #  x x x x x x x       . . . . . . .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x       . . . . . . .
        #
        # The dots represent removed probabilities; the x's mark possible locations
        crop_ini = patch_size // 2
        crop_fin = (patch_size - 1) // 2
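        # For example, with a patch size of 4 along an axis, crop_ini is
        # 4 // 2 = 2 and crop_fin is 3 // 2 = 1; with a patch size of 5,
        # both crops are 2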
        crop_i, crop_j, crop_k = crop_ini
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        # The call to tolist() is very important. Using np.uint16 as a negative
        # index will not work because e.g. -np.uint16(2) == 65534
        crop_i, crop_j, crop_k = crop_fin.tolist()
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(
            probability_map: np.ndarray,
            ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the CDF of a probability map.

        The cumulative distribution function (CDF) is computed as follows:

        1. Flatten probability map
        2. Compute sorting indices
        3. Sort flattened map
        4. Compute cumulative sum

        For example,
        if the probability map is [0.0, 0.0, 0.1, 0.2, 0.5, 0.1, 0.1, 0.0],
        the sorting indices are [0, 1, 7, 2, 5, 6, 3, 4],
        the sorted map is [0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.2, 0.5],
        and the CDF is [0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.5, 1.0].
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        # Get the sorting indices so that we can invert the sorting later on
        sort_indices = np.argsort(flat_map_normalized)
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
        cdf = np.cumsum(flat_map_normalized_sorted)
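        # Because the flattened map was normalized, the last value of the CDF
        # is 1 (up to floating-point error)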
        return cdf, sort_indices

    def extract_patch(
            self,
            sample: Subject,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            sort_indices: np.ndarray,
            ) -> Subject:
        index_ini = self.get_random_index_ini(probability_map, cdf, sort_indices)
        index_fin = index_ini + self.patch_size
        cropped_sample = sample.crop(index_ini, index_fin)
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    def get_random_index_ini(
            self,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            sort_indices: np.ndarray,
            ) -> np.ndarray:
        center = self.sample_probability_map(probability_map, cdf, sort_indices)
        assert np.all(center >= 0)
        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    def sample_probability_map(
            self,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            sort_indices: np.ndarray,
            ) -> np.ndarray:
        """Inverse transform sampling.

        Example:
            >>> probability_map = np.array(
            ...    ((0,0,1,1,5,2,1,1,0),
            ...     (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> histogram = np.zeros_like(probability_map)
            >>> for _ in range(100000):
            ...     histogram[sample_probability_map(probability_map)] += 1
            ...
            >>> histogram
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])

        """
        # Get the first value of the CDF that is larger than a random number
        random_number = torch.rand(1).item()
        # If probability map is float32, cdf.max() can be far from 1, e.g. 0.92
        if random_number > cdf.max():
            cdf_index = -1
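            # cdf_index -1 picks the last entry of the sorted map, i.e. the
            # voxel with the highest probability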
        else:  # proceed as usual
            cdf_index = np.argmax(random_number < cdf)

        random_location_index = sort_indices[cdf_index]
        center = np.unravel_index(
            random_location_index,
            probability_map.shape
        )

        i, j, k = center
        probability = probability_map[i, j, k]
        assert probability > 0

        center = np.array(center).astype(int)
        return center
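The two static helpers above, clear_probability_borders() and get_cumulative_distribution_function(), are pure functions, so they can be sanity-checked in isolation. The snippet below is an illustrative sketch, not part of the reviewed module: it assumes the class is importable under the module path shown in the report header (torchio.data.sampler.weighted), and everything it checks follows from the code listed above.

import numpy as np

from torchio.data.sampler.weighted import WeightedSampler  # module shown above

# A 7 x 7 x 7 map of ones and an even patch size of 4 along every axis
probability_map = np.ones((7, 7, 7), np.float64)
patch_size = np.array((4, 4, 4), np.uint16)

# crop_ini = 4 // 2 = 2 and crop_fin = 3 // 2 = 1, so only a 4 x 4 x 4 core of
# candidate patch centers keeps a nonzero probability
WeightedSampler.clear_probability_borders(probability_map, patch_size)
assert probability_map.sum() == 4 ** 3

# get_cumulative_distribution_function() normalizes internally, so the CDF
# of the float64 map ends at 1
cdf, sort_indices = WeightedSampler.get_cumulative_distribution_function(
    probability_map)
assert np.isclose(cdf[-1], 1)

# The float64 cast in process_probability_map() matters: a float32 running sum
# over millions of voxels accumulates rounding error, which is why
# sample_probability_map() tolerates cdf.max() being smaller than the random
# number.
weights = np.full(10_000_000, 1.0, np.float32)
print(np.cumsum(weights / weights.sum())[-1])                        # may drift from 1.0
print(np.cumsum((weights / weights.sum()).astype(np.float64))[-1])   # very close to 1.0

The last two prints only illustrate the float32 precision issue mentioned in the comments; the sampler itself always works on a float64 copy of the map.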