from typing import Optional, Generator

import torch
import numpy as np

from ...typing import TypePatchSize
from ..image import Image
from ..subject import Subject
from .sampler import RandomSampler


class WeightedSampler(RandomSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
        probability_map: Name of the image in the input subject that will be
            used as a sampling probability map.

    Raises:
        RuntimeError: If the probability map is empty.

    Example:
        >>> import torchio as tio
        >>> subject = tio.Subject(
        ...     t1=tio.ScalarImage('t1_mri.nii.gz'),
        ...     sampling_map=tio.Image('sampling.nii.gz', type=tio.SAMPLING_MAP),
        ... )
        >>> patch_size = 64
        >>> sampler = tio.data.WeightedSampler(patch_size, 'sampling_map')
        >>> for patch in sampler(subject):
        ...     print(patch['index_ini'])

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0 as
        the center of the patch cannot be at the border (unless the patch has
        size 1 or 2 along that axis).

    """
    def __init__(
            self,
            patch_size: TypePatchSize,
            probability_map: str,
            ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        self.cdf = None

    def __call__(
            self,
            subject: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        subject.check_consistent_space()
        if np.any(self.patch_size > subject.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(subject.spatial_shape)}'
            )
            raise RuntimeError(message)
        probability_map = self.get_probability_map(subject)
        probability_map = self.process_probability_map(probability_map, subject)
        cdf = self.get_cumulative_distribution_function(probability_map)

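        # If num_patches is None, patches_left stays True and the loop below
        # yields patches indefinitely; the caller decides when to stop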
        patches_left = num_patches if num_patches is not None else True
        while patches_left:
            yield self.extract_patch(subject, probability_map, cdf)
            if num_patches is not None:
                patches_left -= 1

    def get_probability_map_image(self, subject: Subject) -> Image:
        if self.probability_map_name in subject:
            return subject[self.probability_map_name]
        else:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject: {subject}'
            )
            raise KeyError(message)

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        data = self.get_probability_map_image(subject).data
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(
            self,
            probability_map: torch.Tensor,
            subject: Subject,
            ) -> np.ndarray:
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
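        # Image data are 4D tensors (channels first), so take the first
        # channel to get a 3D spatial map (checked by the assert below)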
        data = probability_map[0].numpy().astype(np.float64)
        assert data.ndim == 3
        self.clear_probability_borders(data, self.patch_size)
        total = data.sum()
        if total == 0:
            message = (
                'Empty probability map found:'
                f' {self.get_probability_map_image(subject).path}'
            )
            raise RuntimeError(message)
        data /= total  # normalize probabilities
        return data

    @staticmethod
    def clear_probability_borders(
            probability_map: np.ndarray,
            patch_size: TypePatchSize,
            ) -> None:
        # Set probability to 0 on voxels that wouldn't possibly be sampled given
        # the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        # Patch      center
        # . . . .    . . . .
        # . . . . -> . . x .
        # . . . .    . . . .
        #
        #
        # Prob. map        After preprocessing
        #
        # x x x x x x x        . . . . . . .
        # x x x x x x x        . . x x x x .
        # x x x x x x x  -->   . . x x x x .
        # x x x x x x x  -->   . . x x x x .
        # x x x x x x x        . . x x x x .
        # x x x x x x x        . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations
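        # Worked example: for patch_size (3, 4, 5), crop_ini = (1, 2, 2) and
        # crop_fin = (1, 1, 2), i.e. crop_ini rows/columns/slices are cleared
        # at the start of each axis and crop_fin at the end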
        crop_ini = patch_size // 2
        crop_fin = (patch_size - 1) // 2
        crop_i, crop_j, crop_k = crop_ini
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        # The call tolist() is very important. Using np.uint16 as negative index
        # will not work because e.g. -np.uint16(2) == 65534
        crop_i, crop_j, crop_k = crop_fin.tolist()
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(
            probability_map: np.ndarray,
            ) -> np.ndarray:
        """Return the cumulative distribution function of a probability map."""
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        cdf = np.cumsum(flat_map_normalized)
        return cdf

    def extract_patch(
            self,
            subject: Subject,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            ) -> Subject:
        index_ini = self.get_random_index_ini(probability_map, cdf)
        cropped_subject = self.crop(subject, index_ini, self.patch_size)
        cropped_subject['index_ini'] = index_ini.astype(int)
        return cropped_subject

    def get_random_index_ini(
            self,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            ) -> np.ndarray:
        center = self.sample_probability_map(probability_map, cdf)
        assert np.all(center >= 0)
        # See self.clear_probability_borders
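        # e.g. center (10, 20, 30) with patch_size (4, 4, 4) gives
        # index_ini (8, 18, 28)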
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    @classmethod
    def sample_probability_map(
            cls,
            probability_map: np.ndarray,
            cdf: np.ndarray,
            ) -> np.ndarray:
        """Inverse transform sampling.

        Example:
            >>> probability_map = np.array(
            ...     ((0,0,1,1,5,2,1,1,0),
            ...      (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> histogram = np.zeros_like(probability_map)
            >>> cdf = WeightedSampler.get_cumulative_distribution_function(
            ...     probability_map)
            >>> for _ in range(100000):
            ...     histogram[WeightedSampler.sample_probability_map(probability_map, cdf)] += 1  # doctest:+SKIP
            ...
            >>> histogram  # doctest:+SKIP
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])

        """
        # Get first value larger than random number
        random_number = torch.rand(1).item()
        # If probability map is float32, cdf.max() can be far from 1, e.g. 0.92
        if random_number > cdf.max():
            cdf_index = -1
        else:  # proceed as usual
            cdf_index = np.searchsorted(cdf, random_number)

        random_location_index = cdf_index
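        # np.unravel_index converts the flat index into the flattened CDF back
        # to voxel coordinates, e.g. for a 2D map of shape (2, 9), flat
        # index 13 maps to (1, 4)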
        center = np.unravel_index(
            random_location_index,
            probability_map.shape
        )

        i, j, k = center
        probability = probability_map[i, j, k]
        assert probability > 0

        center = np.array(center).astype(int)
        return center