Passed
Pull Request — master (#175)
by Fernando
01:04
created

WeightedSampler.sample_probability_map()   A

Complexity

Conditions 2

Size

Total Lines 39
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 14
nop 4
dl 0
loc 39
rs 9.7
c 0
b 0
f 0
1
import copy
2
from typing import Union, Sequence, Generator, Optional
3
4
import numpy as np
5
6
import torch
7
8
from ...torchio import DATA
9
from ..subject import Subject
10
from .sampler import PatchSampler
11
12
13
14
class WeightedSampler(PatchSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        probability_map: Name of the image in the sample that will be used
            as a probability map. It must not contain negative values. Only
            the first channel (``data[0]``) of that image is used.

    Example:
        >>> import torchio
        >>> subject = torchio.Subject(
        ...     t1=torchio.Image('t1_mri.nii.gz', type=torchio.INTENSITY),
        ...     sampling_map=torchio.Image('sampling.nii.gz', type=torchio.SAMPLING_MAP),
        ... )
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> patch_size = 64
        >>> sampler = torchio.data.WeightedSampler(patch_size, probability_map='sampling_map')
        >>> for patch in sampler(sample, num_patches=4):
        ...     print(patch['index_ini'])

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Along an axis where the patch has size :math:`s`, the first
        :math:`\lfloor s/2 \rfloor` and the last :math:`\lfloor (s-1)/2 \rfloor`
        values of the probability map are set to 0, because a patch centered
        there would extend beyond the image. Nothing is cleared for
        :math:`s = 1`, and only the first value is cleared for :math:`s = 2`.

    """
    def __init__(
            self,
            patch_size: Union[int, Sequence[int]],
            probability_map: Optional[str] = None,
            ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        # Kept for backward compatibility: __call__ computes the CDF and the
        # sorting indices as locals and never stores them here
        self.cdf = None
        self.sort_indices = None

    def __call__(
            self,
            sample,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches whose centers are drawn from the probability map.

        This is a generator, so the checks below run when the first patch is
        requested, not when the method is called.

        Args:
            sample: Sample from which patches will be extracted. Its images
                must have a consistent shape.
            num_patches: Number of patches to yield. If ``None``, patches are
                yielded indefinitely. Zero or negative values yield nothing.

        Raises:
            RuntimeError: If the patch is larger than the image along any
                axis.
        """
        sample.check_consistent_shape()
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)
        # The map and its CDF are computed once and reused for every patch
        probability_map = self.get_probability_map(sample)
        probability_map = self.process_probability_map(probability_map)
        cdf, sort_indices = self.get_cumulative_distribution_function(
            probability_map)

        num_yielded = 0
        # Comparing with "<" (instead of counting down until zero) makes a
        # negative num_patches stop immediately rather than loop forever
        while num_patches is None or num_yielded < num_patches:
            yield self.extract_patch(sample, probability_map, cdf, sort_indices)
            num_yielded += 1

    def get_probability_map(self, sample):
        """Return the data of the probability map image stored in ``sample``.

        Raises:
            KeyError: If the sample has no image named
                ``self.probability_map_name``.
            ValueError: If the probability map contains negative values.
        """
        if self.probability_map_name in sample:
            data = sample[self.probability_map_name].data
        else:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(self, probability_map):
        """Return a normalized float64 copy of the map, ready for sampling.

        The first channel of ``probability_map`` is converted to a NumPy
        float64 array. An all-zero map is replaced by a uniform one. The
        result is normalized to sum 1, and the voxels that cannot be the
        center of a patch are set to 0 (see
        :py:meth:`clear_probability_borders`).

        Raises:
            ValueError: If, after clearing the borders, no voxel has a
                positive probability. This happens when all the positive
                values lie too close to the image border.
        """
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
        # astype() returns a copy, so the in-place edits below do not modify
        # the image stored in the sample
        data = probability_map[0].numpy().astype(np.float64)
        assert data.ndim == 3
        if data.sum() == 0:  # although it should not be empty
            data += 1  # make uniform
        data /= data.sum()  # normalize probabilities
        self.clear_probability_borders(data, self.patch_size)
        # Explicit check instead of an assert: asserts are stripped under -O,
        # and this can be triggered by user data, not only by bugs
        if not data.sum() > 0:
            message = (
                f'Probability map "{self.probability_map_name}" has no'
                ' positive values at voxels that can be the center of a'
                f' patch of size {tuple(self.patch_size)}, as values too close'
                ' to the image border are ignored'
            )
            raise ValueError(message)
        return data

    @staticmethod
    def clear_probability_borders(probability_map, patch_size):
        """Zero, in place, the voxels that cannot be the center of a patch.

        Along each axis with patch size ``s``, the first ``s // 2`` and the
        last ``(s - 1) // 2`` values are set to 0.
        """
        # Set probability to 0 on voxels that wouldn't possibly be sampled given
        # the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        # For a patch of size (3, 4) on an image of size (6, 7):
        #
        #  x x x x x x x      . . . . . . .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x      . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations
        crop_ini = patch_size // 2
        crop_fin = (patch_size - 1) // 2
        crop_i, crop_j, crop_k = crop_ini
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        # The call tolist() is very important. Using np.uint16 as negative index
        # will not work because e.g. -np.uint16(2) == 65534
        crop_i, crop_j, crop_k = crop_fin.tolist()
        # A zero crop must be skipped: [-0:] would select the whole axis
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(probability_map):
        """Return the CDF of a probability map.

        The cumulative distribution function (CDF) is computed as follows:

        1. Flatten probability map
        2. Normalize it
        3. Compute sorting indices
        4. Sort flattened map
        5. Compute cumulative sum

        For example,
        if the probability map is [0, 0, 1, 2, 5, 1, 1, 0],
        the normalized version is [0.0, 0.0, 0.1, 0.2, 0.5, 0.1, 0.1, 0.0],
        the sorting indices are [0, 1, 7, 2, 5, 6, 3, 4],
        the sorted map is [0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.2, 0.5],
        and the CDF is [0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.5, 1.0].

        The order of equal values in the sorting indices depends on
        ``np.argsort``.

        Returns:
            A tuple ``(cdf, sort_indices)``. ``cdf`` is non-decreasing, and
            ``sort_indices[n]`` is the flat index of the voxel whose
            probability was accumulated at position ``n`` of ``cdf``.
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        # Get the sorting indices so that we can invert the sorting later on
        sort_indices = np.argsort(flat_map_normalized)
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
        cdf = np.cumsum(flat_map_normalized_sorted)
        return cdf, sort_indices

    def extract_patch(self, sample, probability_map, cdf, sort_indices) -> Subject:
        """Crop one randomly centered patch from ``sample``.

        The returned patch stores the index of its first voxel under the
        ``'index_ini'`` key.
        """
        index_ini = self.get_random_index_ini(probability_map, cdf, sort_indices)
        # get_crop_transform is presumably provided by PatchSampler
        crop = self.get_crop_transform(
            sample,
            index_ini,
            self.patch_size,
        )
        cropped_sample = crop(sample)
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    def get_random_index_ini(self, probability_map, cdf, sort_indices):
        """Sample a patch center and return the index of the patch's first voxel."""
        center = self.sample_probability_map(probability_map, cdf, sort_indices)
        assert np.all(center >= 0)
        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    def sample_probability_map(self, probability_map, cdf, sort_indices):
        """Draw a voxel using inverse transform sampling.

        A random number ``u`` in [0, 1) is drawn with ``torch.rand``, so
        seeding with ``torch.manual_seed`` makes the draws reproducible. The
        first position of ``cdf`` whose value is strictly greater than ``u``
        is selected, and ``sort_indices`` maps it back to a voxel. Voxels with
        zero probability therefore never get selected. Each other voxel is
        selected with probability proportional to its value.

        Args:
            probability_map: 3D array returned by
                :py:meth:`process_probability_map`.
            cdf: CDF returned by
                :py:meth:`get_cumulative_distribution_function`.
            sort_indices: Sorting indices returned by
                :py:meth:`get_cumulative_distribution_function`.

        Returns:
            Integer array with the three indices of the sampled voxel.
        """
        random_number = torch.rand(1).item()
        # First position whose cumulative value is strictly larger than the
        # random number. Binary search: O(log N) per patch instead of scanning
        # the whole CDF
        cdf_index = np.searchsorted(cdf, random_number, side='right')
        # Rounding can make the last CDF value slightly smaller than 1. When
        # the random number is not below it (including exact equality),
        # searchsorted returns len(cdf). Fall back to the last position: it
        # holds the most probable voxel, whose probability is positive
        cdf_index = min(int(cdf_index), len(cdf) - 1)

        random_location_index = sort_indices[cdf_index]
        center = np.unravel_index(
            random_location_index,
            probability_map.shape
        )

        i, j, k = center
        probability = probability_map[i, j, k]
        assert probability > 0

        center = np.array(center).astype(int)
        return center
232