Passed
Pull Request — master (#175)
by Fernando
01:15
created

WeightedSampler.__init__()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 8
nop 3
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import copy
2
from typing import Union, Sequence, Generator, Optional
3
4
import numpy as np
5
6
import torch
7
8
from ...torchio import DATA
9
from ..subject import Subject
10
from .sampler import PatchSampler
11
12
13
14
class WeightedSampler(PatchSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Calling the sampler with a sample generated by a
    :py:class:`~torchio.data.dataset.ImagesDataset` returns an infinite
    generator of patches extracted from that sample.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.
        probability_map: Name of the image in the sample that will be used
            as a probability map. If ``None``, all voxels have the same
            probability of being at the patch center.

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0
        as the center of the patch cannot be at the border. Along an axis
        where the patch has size :math:`s`, the first
        :math:`\lfloor s/2 \rfloor` and the last
        :math:`\lfloor (s - 1)/2 \rfloor` voxels are excluded (no voxel is
        excluded when :math:`s = 1`, only the first one when :math:`s = 2`).

    """
    def __init__(
            self,
            patch_size: Union[int, Sequence[int]],
            probability_map: Optional[str] = None,
            ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        # Not read by any method of this class (the CDF and sorting indices
        # are computed per sample in __call__); kept for backward
        # compatibility with code that may access these attributes.
        self.cdf = None
        self.sort_indices = None

    def __call__(self, sample: Subject) -> Generator[Subject, None, None]:
        """Yield an infinite stream of patches sampled from ``sample``.

        Because this is a generator, the checks below run when the first
        patch is requested, not when the sampler is called.

        Raises:
            RuntimeError: If the patch size is larger than the image size
                along any axis.
        """
        sample.check_consistent_shape()
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)
        probability_map = self.get_probability_map(sample)
        probability_map = self.process_probability_map(probability_map)
        # The CDF depends only on the map, so compute it once per sample
        cdf, sort_indices = self.get_cumulative_distribution_function(
            probability_map)

        while True:
            yield self.extract_patch(sample, probability_map, cdf, sort_indices)

    def get_probability_map(self, sample: Subject) -> torch.Tensor:
        """Return the (unnormalized) probability map tensor for ``sample``.

        If no probability map name was given, a tensor of ones with shape
        ``(1, *sample.spatial_shape)`` is returned, so that every voxel has
        the same probability of being at the patch center.

        Raises:
            KeyError: If the named probability map is not in ``sample``.
            ValueError: If the probability map contains negative values.
        """
        if self.probability_map_name is None:
            # Documented behavior: uniform probabilities. Previously this
            # path looked up ``None`` in the sample and raised KeyError.
            shape = (1, *(int(n) for n in sample.spatial_shape))
            return torch.ones(shape)
        if self.probability_map_name in sample:
            data = sample[self.probability_map_name].data
        else:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(self, probability_map: torch.Tensor) -> np.ndarray:
        """Convert the map to a normalized 3D float64 array with borders cleared.

        Only the first channel of ``probability_map`` is used. An all-zero
        map is treated as uniform.

        Raises:
            ValueError: If the first channel is not 3D.
            RuntimeError: If no positive probability remains after removing
                the voxels that cannot be a patch center.
        """
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
        data = probability_map[0].numpy().astype(np.float64)
        if data.ndim != 3:
            message = (
                'The probability map must be 3D after taking its first'
                f' channel, but it has shape {data.shape}'
            )
            raise ValueError(message)
        if data.sum() == 0:  # although it should not be empty
            data += 1  # make uniform
        self.clear_probability_borders(data, self.patch_size)
        total = data.sum()
        if total == 0:
            # Explicit error instead of an assert, which is stripped under -O
            # and would otherwise lead to a NaN CDF downstream
            message = (
                f'Probability map "{self.probability_map_name}" has no'
                ' positive values far enough from the image border to be the'
                f' center of a patch of size {tuple(self.patch_size)}'
            )
            raise RuntimeError(message)
        # Normalize after clearing so the returned map sums to 1
        data /= total
        return data

    @staticmethod
    def clear_probability_borders(probability_map, patch_size):
        # Set probability to 0 (in place) on voxels that wouldn't possibly be
        # sampled given the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        # For a patch of size (3, 4) on a (6, 7) map:
        #
        #  x x x x x x x      . . . . . . .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x      . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations
        patch_size = np.asarray(patch_size)  # also accept plain sequences
        crop_ini = patch_size // 2
        crop_fin = (patch_size - 1) // 2

        crop_i, crop_j, crop_k = crop_ini.tolist()
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        # The call tolist() is very important. Using np.uint16 as negative index
        # will not work because e.g. -np.uint16(2) == 65534
        # A zero crop must be skipped: [-0:] would select the whole axis
        crop_i, crop_j, crop_k = crop_fin.tolist()
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(probability_map):
        """Return the CDF of a probability map.

        The cumulative distribution function (CDF) is computed as follows:

        1. Flatten probability map
        2. Normalize it
        3. Compute sorting indices
        4. Sort flattened map
        5. Compute cumulative sum

        For example,
        if the probability map is [0, 0, 1, 2, 5, 1, 1, 0],
        the normalized version is [0.0, 0.0, 0.1, 0.2, 0.5, 0.1, 0.1, 0.0],
        the sorting indices are [0, 1, 7, 2, 5, 6, 3, 4]
        (ties may come out in a different order),
        the sorted map is [0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.2, 0.5],
        and the CDF is [0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.5, 1.0].

        Returns:
            A tuple ``(cdf, sort_indices)``. ``cdf`` is non-decreasing and
            ``sort_indices`` maps positions in ``cdf`` back to flat indices
            of ``probability_map``.
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        # Get the sorting indices so that we can invert the sorting later on
        sort_indices = np.argsort(flat_map_normalized)
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
        cdf = np.cumsum(flat_map_normalized_sorted)
        return cdf, sort_indices

    def extract_patch(self, sample, probability_map, cdf, sort_indices) -> Subject:
        """Sample a patch location and return the cropped copy of ``sample``."""
        # TODO: replace with Crop transform
        index_ini = self.get_random_index_ini(probability_map, cdf, sort_indices)
        cropped_sample = self.copy_and_crop(sample, index_ini)
        assert cropped_sample.spatial_shape == tuple(self.patch_size)
        return cropped_sample

    def get_random_index_ini(self, probability_map, cdf, sort_indices):
        """Return the index of the first voxel of a randomly sampled patch."""
        center = self.sample_probability_map(probability_map, cdf, sort_indices)
        assert np.all(center >= 0)

        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    def sample_probability_map(self, probability_map, cdf, sort_indices):
        """Sample a voxel index using inverse transform sampling.

        A uniform random number in [0, 1) is drawn and the first CDF entry
        strictly greater than it is selected. The corresponding voxel is
        found through ``sort_indices``. Each voxel is therefore chosen with
        probability proportional to its value in ``probability_map``. For
        example, along a row with values (0, 0, 1, 1, 5, 2, 1, 1, 0), the
        voxel with value 5 is chosen about five times as often as a voxel
        with value 1, and voxels with value 0 are never chosen.

        Returns:
            Integer array ``(i, j, k)`` with the index of the sampled voxel.
        """
        random_number = torch.rand(1).item()
        # Index of the first CDF value larger than the random number. This is
        # the same index as np.argmax(random_number < cdf), found in O(log n).
        cdf_index = int(np.searchsorted(cdf, random_number, side='right'))
        # If the probability map is float32, or because of rounding, cdf[-1]
        # can be below 1. Then random_number >= cdf[-1] is possible and
        # searchsorted returns len(cdf); clamp to the last (most probable)
        # entry. Previously the case random_number == cdf.max() fell through
        # to argmax of an all-False mask, i.e. index 0, a zero-probability
        # voxel.
        cdf_index = min(cdf_index, len(cdf) - 1)

        random_location_index = sort_indices[cdf_index]
        center = np.unravel_index(
            random_location_index,
            probability_map.shape
        )

        i, j, k = center
        probability = probability_map[i, j, k]
        assert probability > 0

        center = np.array(center).astype(int)
        return center

    def copy_and_crop(self, sample, index_ini: np.ndarray) -> Subject:
        """Return a copy of ``sample`` whose images are cropped to one patch.

        The patch starts at ``index_ini`` and has size ``self.patch_size``.
        The start index is stored under the ``'index_ini'`` key of the
        returned sample.

        NOTE(review): the cropped data are slices (views) of the tensors in
        the original ``sample``, as in the previous implementation. Modifying
        a patch in place also modifies the source sample.
        """
        # NOTE(review): static analysis reported this method as duplicated
        # elsewhere in the project; consider sharing one implementation.
        index_fin = index_ini + self.patch_size
        # A single deep copy is enough. The previous implementation also
        # deep-copied every image a second time, although the copied data
        # was discarded and replaced by a slice of the original data below.
        cropped_sample = copy.deepcopy(sample)
        iterable = sample.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            cropped_sample[image_name][DATA] = self.crop(
                image[DATA], index_ini, index_fin)
        # torch doesn't like uint16
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    @staticmethod
    def crop(
            data: Union[np.ndarray, torch.Tensor],
            index_ini: np.ndarray,
            index_fin: np.ndarray,
            ) -> Union[np.ndarray, torch.Tensor]:
        """Slice the last three axes of ``data`` to ``[index_ini, index_fin)``.

        The first axis (channels) is kept whole. The result is a view when
        ``data`` is a NumPy array or a torch tensor.
        """
        i_ini, j_ini, k_ini = index_ini
        assert np.all(np.array(index_fin) <= np.array(data.shape[1:]))
        i_fin, j_fin, k_fin = index_fin
        return data[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
242