|
1
|
|
|
import copy |
|
2
|
|
|
from typing import Union, Sequence, Generator, Tuple, Optional |
|
3
|
|
|
|
|
4
|
|
|
import numpy as np |
|
5
|
|
|
|
|
6
|
|
|
import torch |
|
7
|
|
|
from torch.utils.data import IterableDataset |
|
8
|
|
|
|
|
9
|
|
|
from ...torchio import DATA |
|
10
|
|
|
from ...utils import to_tuple |
|
11
|
|
|
from ..subject import Subject |
|
12
|
|
|
|
|
13
|
|
|
|
|
14
|
|
|
|
|
15
|
|
|
class WeightedSampler(IterableDataset):
    r"""Extract random patches from a volume, weighted by a probability map.

    Patch centers are drawn with inverse transform sampling from the
    probability map, after removing the voxels that are too close to the
    border to be the center of a patch that lies fully inside the volume.
    Iterating over the sampler yields patches indefinitely.

    Args:
        sample: Sample generated by a
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
            patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.
        probability_map_name: Name of the image in the sample that will be used
            as a probability map. If the name is not found in the sample
            (e.g. when it is ``None``), a uniform map is used.

    Raises:
        ValueError: If the patch is larger than the image along any
            dimension, if the probability map contains negative values or
            if none of its non-zero voxels can be the center of a patch.
    """
    def __init__(
            self,
            sample: 'Subject',
            patch_size: Union[int, Sequence[int]],
            probability_map_name: Optional[str] = None,
            ):
        sample.check_consistent_shape()
        self.sample = sample
        patch_size = to_tuple(patch_size, length=3)
        self.patch_size = np.array(patch_size, dtype=np.uint16)
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise ValueError(message)
        self.probability_map = self.process_probability_map(
            probability_map_name)
        self.cdf, self.sort_indices = self.get_cumulative_distribution_function(
            self.probability_map)

    def __iter__(self) -> Generator['Subject', None, None]:
        """Yield random patches forever; the caller decides when to stop."""
        while True:
            yield self.extract_patch()

    def process_probability_map(
            self,
            probability_map_name: Optional[str],
            ) -> np.ndarray:
        """Build the 3D map of probabilities of each voxel being a center.

        Args:
            probability_map_name: Name of the image in ``self.sample`` to use
                as probability map. A uniform map is used if the sample does
                not contain it.

        Returns:
            Array of type ``float64`` whose values add up to 1 and are 0 at
            the locations cleared by :meth:`clear_probability_borders`.

        Raises:
            ValueError: If the map contains negative values or if all its
                non-zero values are too close to the border.
        """
        if probability_map_name in self.sample:
            data = self.sample[probability_map_name].data
        else:
            data = torch.ones(self.sample.shape)
        # Using float32 creates cdf with maximum very far from 1, e.g. 0.92!
        # astype() returns a new array, so the in-place operations below
        # never modify the image stored in the sample
        data = data[0].numpy().astype(np.float64)
        assert data.ndim == 3
        if np.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{probability_map_name}"'
            )
            raise ValueError(message)
        if data.sum() == 0:  # although it should not be empty
            data += 1  # make uniform
        self.clear_probability_borders(data, self.patch_size)
        # Normalize after clearing the borders so that the result is a
        # proper probability distribution over the valid centers
        total = data.sum()
        if total == 0:
            patch_size = tuple(int(n) for n in self.patch_size)
            message = (
                'All non-zero values'
                f' in probability map "{probability_map_name}"'
                ' are too close to the border to be the center'
                f' of a patch of size {patch_size}'
            )
            raise ValueError(message)
        data /= total
        return data

    @staticmethod
    def clear_probability_borders(probability_map, patch_size) -> None:
        """Zero, in place, the voxels that cannot be the center of a patch.

        Args:
            probability_map: 3D array, modified in place.
            patch_size: Sequence of three integers with the patch size.
        """
        # Set probability to 0 on voxels that wouldn't possibly be sampled
        # given the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #  x x x x x x x       . . . . . . .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x       . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations

        # Plain Python integers are needed here: negating an unsigned NumPy
        # integer wraps around (e.g. -np.uint16(2) == 65534), which would
        # turn the trailing slices below into empty ones
        size_i, size_j, size_k = [int(n) for n in patch_size]

        # Voxels before the center: n // 2
        crop_i, crop_j, crop_k = size_i // 2, size_j // 2, size_k // 2
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        # Voxels after the center: n - 1 - n // 2 == (n - 1) // 2
        crop_i, crop_j, crop_k = [
            max(n - 1, 0) // 2 for n in (size_i, size_j, size_k)
        ]
        # A slice such as [-0:] would select everything, hence the checks
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    def get_random_index_ini(self) -> np.ndarray:
        """Return the index of the first voxel of a randomly located patch."""
        center = self.sample_probability_map()

        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    @staticmethod
    def get_cumulative_distribution_function(
            probability_map,
            ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the CDF of the sorted, flattened probability map.

        Args:
            probability_map: Array of non-negative weights. It does not need
                to be normalized.

        Returns:
            Tuple ``(cdf, sort_indices)``, where ``cdf`` is the cumulative
            sum of the normalized weights in ascending order (so its last
            value is 1 up to rounding) and ``sort_indices`` maps positions
            in ``cdf`` to indices of the flattened map.

        Raises:
            ValueError: If the weights add up to 0.
        """
        flat_map = probability_map.flatten()
        total = flat_map.sum()
        if total == 0:
            raise ValueError('The probability map does not contain any'
                             ' non-zero value')
        flat_map_normalized = flat_map / total
        # Get the sorting indices so that we can invert the sorting later on
        sort_indices = np.argsort(flat_map_normalized)
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
        cdf = np.cumsum(flat_map_normalized_sorted)
        return cdf, sort_indices

    def sample_probability_map(self) -> np.ndarray:
        """Draw a random location using inverse transform sampling.

        Reads ``self.cdf``, ``self.sort_indices`` and the shape of
        ``self.probability_map``.

        Returns:
            Integer array with the index of the sampled voxel along each
            dimension of the probability map.

        Example (illustrative only: it is written as if this method were a
        function taking a 2D map, and counts change from run to run):
            >>> probability_map = np.array(
            ...    ((0,0,1,1,5,2,1,1,0),
            ...     (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> histogram = np.zeros_like(probability_map)
            >>> for _ in range(100000):
            ...     histogram[sample_probability_map(probability_map)] += 1
            ...
            >>> histogram
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])

        """
        random_number = torch.rand(1).item()
        # Get first value larger than random number. The cdf is sorted, so a
        # binary search is enough
        cdf_index = np.searchsorted(self.cdf, random_number, side='right')
        # Because of rounding, the last value of the cdf might be slightly
        # smaller than 1. Use the last (most likely) location in that case
        cdf_index = min(cdf_index, len(self.cdf) - 1)

        random_location_index = self.sort_indices[cdf_index]
        center = np.unravel_index(
            random_location_index,
            self.probability_map.shape
        )
        center = np.array(center).astype(int)
        return center

    def extract_patch(self) -> 'Subject':
        """Return a copy of the sample cropped around a random location."""
        # TODO: replace with Crop transform
        index_ini = self.get_random_index_ini()
        cropped_sample = self.copy_and_crop(index_ini)
        return cropped_sample

    def copy_and_crop(self, index_ini: np.ndarray) -> dict:
        """Copy the sample and crop the data of each of its images.

        Args:
            index_ini: Index of the first voxel of the patch.

        Returns:
            Deep copy of the sample in which the data of every image has
            been replaced by the patch starting at ``index_ini``. The index
            is also stored in the result under the key ``'index_ini'``.
        """
        index_fin = index_ini + self.patch_size
        cropped_sample = copy.deepcopy(self.sample)
        iterable = self.sample.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            cropped_sample[image_name] = copy.deepcopy(image)
            cropped_image = cropped_sample[image_name]
            # The patch is sliced from the data of the original image
            cropped_image[DATA] = self.crop(
                image[DATA], index_ini, index_fin)
        # torch doesn't like uint16
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    @staticmethod
    def crop(
            data: Union[np.ndarray, torch.Tensor],
            index_ini: np.ndarray,
            index_fin: np.ndarray,
            ) -> Union[np.ndarray, torch.Tensor]:
        """Slice the last three dimensions of ``data``.

        Args:
            data: Array or tensor with at least three dimensions.
            index_ini: First index (inclusive) along each of the last three
                dimensions.
            index_fin: Last index (exclusive) along each of the last three
                dimensions.

        Returns:
            The result of slicing ``data``; leading dimensions are kept.
        """
        i_ini, j_ini, k_ini = index_ini
        i_fin, j_fin, k_fin = index_fin
        return data[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]