1
|
|
|
""" |
2
|
|
|
Module containing utilities for layer inputs |
3
|
|
|
""" |
4
|
|
|
import itertools |
5
|
|
|
|
6
|
|
|
import numpy as np |
7
|
|
|
import tensorflow as tf |
8
|
|
|
|
9
|
|
|
|
10
|
|
|
def get_reference_grid(grid_size: (tuple, list)) -> tf.Tensor:
    """
    Build a dense 3D grid holding the integer coordinates of every voxel.

    Reference:

    - volshape_to_meshgrid of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      neuron modifies meshgrid to make it faster, however a local
      benchmark suggests tf.meshgrid is better.

    Note:

    for tf.meshgrid, in the 3-D case with inputs of length M, N and P,
    outputs are of shape (N, M, P) for "xy" indexing and
    (M, N, P) for "ij" indexing. "ij" is used here so that the output
    axes follow the order of grid_size.

    :param grid_size: list or tuple of size 3, [dim1, dim2, dim3];
        only the first three entries are used
    :return: float32 tensor of shape (dim1, dim2, dim3, 3) with
        grid[i, j, k, :] = [i, j, k]
    """
    # one integer range per spatial axis
    axis_ranges = [tf.range(grid_size[axis]) for axis in range(3)]
    # with "ij" indexing, coords[d][i, j, k] is the d-th index of voxel (i, j, k);
    # each element has shape (dim1, dim2, dim3)
    coords = tf.meshgrid(*axis_ranges, indexing="ij")
    # move the coordinate index to the last axis -> (dim1, dim2, dim3, 3)
    return tf.cast(tf.stack(coords, axis=3), dtype=tf.float32)
50
|
|
|
|
51
|
|
|
|
52
|
|
|
def get_n_bits_combinations(num_bits: int) -> list:
    """
    Enumerate every assignment of ``num_bits`` binary digits.

    The combinations are listed in ascending order of the integer they encode,
    most significant bit first, so ``result[i]`` is the ``num_bits``-digit
    binary representation of ``i``. There are ``2 ** num_bits`` of them.

    :param num_bits: int, number of bits in each combination, must be >= 1
    :return: a list of length ``2 ** num_bits``; each element is a list of
        ``num_bits`` ints, each 0 or 1

    :Example:

    >>> from deepreg.model.layer_util import get_n_bits_combinations
    >>> get_n_bits_combinations(3)
    [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]
    """
    assert num_bits >= 1
    # itertools.product varies its last position fastest, which is exactly
    # binary counting order; each tuple is converted to a list
    return list(map(list, itertools.product((0, 1), repeat=num_bits)))
76
|
|
|
|
77
|
|
|
|
78
|
|
|
def pyramid_combination(
    values: list, weight_floor: list, weight_ceil: list
) -> tf.Tensor:
    r"""
    Calculates linear interpolation (a weighted sum) using values of
    hypercube corners in dimension n.

    For example, when num_dimension = len(loc_shape) = num_bits = 3
    values correspond to values at corners of following coordinates

    .. code-block:: python

        [[0, 0, 0], # even
         [0, 0, 1], # odd
         [0, 1, 0], # even
         [0, 1, 1], # odd
         [1, 0, 0], # even
         [1, 0, 1], # odd
         [1, 1, 0], # even
         [1, 1, 1]] # odd

    values[::2] correspond to the corners with last coordinate == 0

    .. code-block:: python

        [[0, 0, 0],
         [0, 1, 0],
         [1, 0, 0],
         [1, 1, 0]]

    values[1::2] correspond to the corners with last coordinate == 1

    .. code-block:: python

        [[0, 0, 1],
         [0, 1, 1],
         [1, 0, 1],
         [1, 1, 1]]

    The weights correspond to the floor corners.
    For example, when num_dimension = len(loc_shape) = num_bits = 3,
    weight_floor = [f1, f2, f3] (ignoring the batch dimension).
    weight_ceil = [c1, c2, c3] (ignoring the batch dimension).

    So for corner with coords (x, y, z), x, y, z's values are 0 or 1

    - weight for x = f1 if x = 0 else c1
    - weight for y = f2 if y = 0 else c2
    - weight for z = f3 if z = 0 else c3

    so the weight for (x, y, z) is

    .. code-block:: text

        W_xyz = ((1-x) * f1 + x * c1)
              * ((1-y) * f2 + y * c2)
              * ((1-z) * f3 + z * c3)

    Let

    .. code-block:: text

        W_xy = ((1-x) * f1 + x * c1)
             * ((1-y) * f2 + y * c2)

    Then

    - W_xy0 = W_xy * f3
    - W_xy1 = W_xy * c3

    Similar to W_xyz, denote V_xyz the value at (x, y, z),
    the final sum V equals

    .. code-block:: text

          sum over x,y,z (V_xyz * W_xyz)
        = sum over x,y (V_xy0 * W_xy0 + V_xy1 * W_xy1)
        = sum over x,y (V_xy0 * W_xy * f3 + V_xy1 * W_xy * c3)
        = sum over x,y (V_xy0 * W_xy) * f3 + sum over x,y (V_xy1 * W_xy) * c3

    That's why we call this pyramid combination: the key is that the weight
    of each corner is the product of the weights along each dimension, so the
    sum can be collapsed one dimension at a time. The implementation collapses
    the first dimension first, halving the list of partial sums on each pass;
    this evaluates exactly the grouping obtained by expanding the formula
    above recursively from the last dimension.

    :param values: a list having values on the corner,
        it has 2**n tensors of shape
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
        the order is consistent with get_n_bits_combinations
        loc_shape is independent from n, aka num_dim
    :param weight_floor: a list having weights of floor points,
        it has n tensors of shape
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :param weight_ceil: a list having weights of ceil points,
        it has n tensors of shape
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :return: one tensor of the same shape as an element in values
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if any list is empty, if the elements do not all have
        the same rank, if len(values) != 2 ** len(weight_floor), or if
        weight_floor and weight_ceil have different lengths
    """
    if len(values) == 0 or len(weight_floor) == 0 or len(weight_ceil) == 0:
        raise ValueError(
            "In pyramid_combination, values, weight_floor and weight_ceil "
            "must all be non-empty, "
            f"got len(values) = {len(values)}, "
            f"len(weight_floor) = {len(weight_floor)}, "
            f"len(weight_ceil) = {len(weight_ceil)}."
        )

    # every corner value and every weight must share the same rank, otherwise
    # broadcasting in the weighted sum could silently produce a wrong shape
    expected_rank = len(values[0].shape)
    for list_name, tensors in (
        ("values", values),
        ("weight_floor", weight_floor),
        ("weight_ceil", weight_ceil),
    ):
        for idx, tensor in enumerate(tensors):
            if len(tensor.shape) != expected_rank:
                raise ValueError(
                    "In pyramid_combination, elements of "
                    "values, weight_floor, and weight_ceil should have same dimension. "
                    f"value shape = {values[0].shape}, "
                    f"{list_name}[{idx}] shape = {tensor.shape}."
                )

    num_dim = len(weight_floor)
    if 2 ** num_dim != len(values):
        raise ValueError(
            "In pyramid_combination, "
            "num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim, "
            f"But len(weight_floor) = {len(weight_floor)}, "
            f"len(values) = {len(values)}"
        )
    # without this check a longer/shorter weight_ceil would be paired with the
    # wrong dimension and yield a silently wrong interpolation
    if len(weight_ceil) != num_dim:
        raise ValueError(
            "In pyramid_combination, "
            "weight_floor and weight_ceil must have the same length, "
            f"got len(weight_floor) = {num_dim}, "
            f"len(weight_ceil) = {len(weight_ceil)}."
        )

    # Collapse one dimension per pass, starting from the first. With the
    # get_n_bits_combinations ordering (first bit most significant), partial
    # sums i and i + half differ only in the bit of the dimension being
    # collapsed, so each pass halves the list.
    partial = values
    for dim in range(num_dim):
        half = len(partial) // 2
        partial = [
            partial[i] * weight_floor[dim] + partial[i + half] * weight_ceil[dim]
            for i in range(half)
        ]
    return partial[0]
214
|
|
|
|
215
|
|
|
|
216
|
|
|
def resample(
    vol: tf.Tensor,
    loc: tf.Tensor,
    interpolation: str = "linear",
    zero_boundary: bool = True,
) -> tf.Tensor:
    r"""
    Sample the volume at given locations using linear interpolation.

    Input has

    - volume, vol, of shape = (batch, v_dim 1, ..., v_dim n),
      or (batch, v_dim 1, ..., v_dim n, ch),
      where n is the dimension of volume,
      ch is the extra dimension as features.

      Denote vol_shape = (v_dim 1, ..., v_dim n)

    - location, loc, of shape = (batch, l_dim 1, ..., l_dim m, n),
      where m is the dimension of output.

      Denote loc_shape = (l_dim 1, ..., l_dim m)

    Boundary handling, per axis d with size s = vol_shape[d]: loc, ceil(loc)
    and ceil(loc) - 1 are all clipped to [0, s - 1] before the weights are
    computed. As a result

    - with zero_boundary=True both weights along d are zero (so the sample is
      zero) when loc <= 0 or loc > s - 1; note that loc == s - 1 is still
      sampled (weight 1 on the ceil corner), so the two edges are not
      treated symmetrically;
    - with zero_boundary=False out-of-range coordinates take the value at the
      nearest edge.

    Reference:

    - neuron's interpn
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

    Differences from neuron:

    1. neuron has no batch dimension
    2. neuron supports more dimensions in vol

    TODO try not using stack as neuron claims it's slower

    :param vol: shape = (batch, \*vol_shape) or (batch, \*vol_shape, ch)
        with the last channel for features
    :param loc: shape = (batch, \*loc_shape, n)
        such that loc[b, l1, ..., lm, :] = [v1, ..., vn] is of shape (n,),
        which represents a point in vol, with coordinates (v1, ..., vn)
    :param interpolation: linear only, TODO support nearest
    :param zero_boundary: if true, out-of-range samples are zero, see above
    :return: shape = (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if interpolation is not "linear", or if the rank of
        vol is inconsistent with the last dimension of loc
    """
    if interpolation != "linear":
        raise ValueError("resample supports only linear interpolation")

    batch_size = vol.shape[0]
    loc_shape = loc.shape[1:-1]
    dim_vol = loc.shape[-1]  # n, number of spatial axes of vol

    vol_rank = len(vol.shape)
    if dim_vol == vol_rank - 1:  # vol is (batch, *vol_shape)
        has_ch = False
    elif dim_vol == vol_rank - 2:  # vol is (batch, *vol_shape, ch)
        has_ch = True
    else:
        raise ValueError(
            "vol shape inconsistent with loc "
            "vol.shape = {}, loc.shape = {}".format(vol.shape, loc.shape)
        )
    vol_shape = vol.shape[1 : dim_vol + 1]

    # stack loc with its ceil - 1 and ceil neighbours so that all three are
    # clipped together -> (batch, *loc_shape, n, 3)
    loc_ceil = tf.math.ceil(loc)
    stacked = tf.stack([loc, loc_ceil - 1, loc_ceil], axis=-1)
    # per-axis upper bound (size - 1), reshaped to broadcast against stacked
    upper_bound = tf.reshape(
        tf.cast(vol_shape, dtype=stacked.dtype) - 1,
        shape=[1] * (len(loc_shape) + 1) + [dim_vol, 1],
    )
    stacked = tf.clip_by_value(stacked, clip_value_min=0, clip_value_max=upper_bound)

    # corner_coords[d] = (floor index, ceil index) along axis d, each int32 of
    # shape (batch, *loc_shape); weight_floor[d] / weight_ceil[d] are the
    # matching weights, with a trailing size-1 axis when vol has channels
    corner_coords, weight_floor, weight_ceil = [], [], []
    for axis in range(dim_vol):
        coord = stacked[..., axis, 0]
        coord_floor = stacked[..., axis, 1]
        coord_ceil = stacked[..., axis, 2]

        w_floor = coord_ceil - coord
        if zero_boundary:
            w_ceil = coord - coord_floor
        else:
            w_ceil = 1 - w_floor
        if has_ch:
            w_floor = tf.expand_dims(w_floor, -1)
            w_ceil = tf.expand_dims(w_ceil, -1)

        corner_coords.append(
            (tf.cast(coord_floor, tf.int32), tf.cast(coord_ceil, tf.int32))
        )
        weight_floor.append(w_floor)
        weight_ceil.append(w_ceil)

    # batch_coords[b, l1, ..., lm] = b, i.e. range(batch_size) along axis 0
    # broadcast over loc_shape; prepending it manually is faster than using
    # batch_dims in tf.gather_nd
    batch_coords = tf.tile(
        tf.reshape(tf.range(batch_size), [batch_size] + [1] * len(loc_shape)),
        [1] + loc_shape,
    )  # shape = (batch, *loc_shape)

    # gather vol at all 2**n hypercube corners, in get_n_bits_combinations
    # order; each gathered tensor is (batch, *loc_shape) or (batch, *loc_shape, ch)
    corner_values = []
    for bits in get_n_bits_combinations(num_bits=len(vol_shape)):
        # (batch, *loc_shape, n + 1): batch index followed by n spatial indices
        gather_indices = tf.stack(
            [batch_coords]
            + [corner_coords[axis][bit] for axis, bit in enumerate(bits)],
            axis=-1,
        )
        corner_values.append(tf.gather_nd(vol, gather_indices))

    return pyramid_combination(
        values=corner_values, weight_floor=weight_floor, weight_ceil=weight_ceil
    )
348
|
|
|
|
349
|
|
|
|
350
|
|
|
def warp_grid(grid: tf.Tensor, theta: tf.Tensor) -> tf.Tensor:
    """
    Apply a batch of affine transformations to a reference grid.

    Each grid point [i, j, k] is lifted to homogeneous coordinates
    [i, j, k, 1] and right-multiplied by the (4, 3) matrix theta[b]:

    - grid_padded[i,j,k,:] = [i j k 1]
    - grid_warped[b,i,j,k,p] = sum_over_q (grid_padded[i,j,k,q] * theta[b,q,p])

    :param grid: shape = (dim1, dim2, dim3, 3), grid[i,j,k,:] = [i j k]
    :param theta: parameters of transformation, shape = (batch, 4, 3)
    :return: shape = (batch, dim1, dim2, dim3, 3)
    """
    # a constant-one coordinate with the grid's dtype, shape (dim1, dim2, dim3, 1)
    ones = tf.ones_like(grid[..., :1])
    # homogeneous coordinates, shape (dim1, dim2, dim3, 4)
    homogeneous = tf.concat([grid, ones], axis=3)
    # contract the homogeneous axis q against each batch's theta
    return tf.einsum("ijkq,bqp->bijkp", homogeneous, theta)
369
|
|
|
|
370
|
|
|
|
371
|
|
|
def resize3d(
    image: tf.Tensor, size: (tuple, list), method: str = tf.image.ResizeMethod.BILINEAR
) -> tf.Tensor:
    """
    Resize a 3D image, optionally batched and/or with channels.

    Tensorflow does not have resize 3d, therefore the resize is performed
    in two 2D passes:

    - resize dim2 and dim3 (dim1 folded into the batch axis)
    - resize dim1 and dim2 (out_dim3 folded into the channel axis)

    :param image: tensor of shape = (batch, dim1, dim2, dim3, channels)
        or (batch, dim1, dim2, dim3)
        or (dim1, dim2, dim3)
    :param size: tuple or list of length 3, (out_dim1, out_dim2, out_dim3)
    :param method: str, one of tf.image.ResizeMethod
    :return: tensor with the same layout as the input, i.e. of shape
        (batch, out_dim1, out_dim2, out_dim3, channels)
        or (batch, out_dim1, out_dim2, out_dim3)
        or (out_dim1, out_dim2, out_dim3).
        If the spatial shape of image already equals size, image is returned
        unchanged (no resize, no dtype change).
    :raises ValueError: if image is not of rank 3, 4 or 5, or if size does not
        have length 3
    """
    # sanity check
    image_dim = len(image.shape)
    if image_dim not in [3, 4, 5]:
        raise ValueError(
            "resize3d takes input image of dimension 3 or 4 or 5, "
            "corresponding to (dim1, dim2, dim3) "
            "or (batch, dim1, dim2, dim3) "
            "or (batch, dim1, dim2, dim3, channels), "
            f"got image shape {image.shape}"
        )
    if len(size) != 3:
        raise ValueError("resize3d takes size of type tuple/list and of length 3")

    # rank 5: (batch, d1, d2, d3, ch); rank 4: (batch, d1, d2, d3); rank 3: (d1, d2, d3)
    has_channel = image_dim == 5
    has_batch = image_dim in (4, 5)
    input_image_shape = image.shape[1:4] if has_batch else image.shape[0:3]

    # no need of resize
    if input_image_shape == tuple(size):
        return image

    # expand to five dimensions
    if not has_batch:
        image = tf.expand_dims(image, axis=0)
    if not has_channel:
        image = tf.expand_dims(image, axis=-1)
    assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
    # dynamic shape so that an unknown batch size is supported
    image_shape = tf.shape(image)

    # merge axis 0 and 1
    output = tf.reshape(
        image, (-1, image_shape[2], image_shape[3], image_shape[4])
    )  # (batch * dim1, dim2, dim3, channels)

    # resize dim2 and dim3
    output = tf.image.resize(
        images=output, size=size[1:], method=method
    )  # (batch * dim1, out_dim2, out_dim3, channels)

    # split axis 0 and merge axis 3 and 4
    output = tf.reshape(
        output, shape=(-1, image_shape[1], size[1], size[2] * image_shape[4])
    )  # (batch, dim1, out_dim2, out_dim3 * channels)

    # resize dim1 and dim2 (dim2 already has its target size)
    output = tf.image.resize(
        images=output, size=size[:2], method=method
    )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

    # restore the channel axis
    output = tf.reshape(
        output, shape=[-1, *size, image_shape[4]]
    )  # (batch, out_dim1, out_dim2, out_dim3, channels)

    # squeeze to original dimension
    if not has_batch:
        output = tf.squeeze(output, axis=0)
    if not has_channel:
        output = tf.squeeze(output, axis=-1)
    return output
459
|
|
|
|
460
|
|
|
|
461
|
|
|
def gaussian_filter_3d(kernel_sigma: (list, tuple, int)) -> tf.Tensor:
    """
    Define a gaussian filter in 3d for smoothing.

    Along each axis the filter size is the smallest odd integer greater than
    or equal to 3 * sigma, so the kernel always has a centre voxel. The
    gaussian values are normalised to sum to 1.

    :param kernel_sigma: the standard deviation along each axis (list or tuple
        of length 3), or a single isotropic standard deviation (int or float);
        every value must be positive
    :return: float32 tensor of shape (*kernel_size, 3, 3), where
        kernel_size[d] is the smallest odd integer >= 3 * kernel_sigma[d].
        total_kernel[..., c, c] holds the gaussian for c in {0, 1, 2} and the
        off-diagonal entries are zero (presumably used as a conv3d filter so
        that each of 3 channels is smoothed independently — confirm against
        callers)
    :raises ValueError: if any sigma is not positive
    """
    if isinstance(kernel_sigma, (int, float, np.integer, np.floating)):
        kernel_sigma = (kernel_sigma, kernel_sigma, kernel_sigma)
    # a zero sigma gives a 0/0 in the exponent and an all-NaN kernel
    if any(ks <= 0 for ks in kernel_sigma):
        raise ValueError(
            f"gaussian_filter_3d expects positive kernel_sigma, got {kernel_sigma}"
        )

    # smallest odd integer >= 3 * sigma: add 1 when ceil(3 * sigma) is even
    kernel_size = [
        int(np.ceil(ks * 3) + np.mod(np.ceil(ks * 3) + 1, 2)) for ks in kernel_sigma
    ]

    # coordinate grid of shape (3, kx, ky, kz); xyz_grid[d, i, j, k] is the
    # d-th index of voxel (i, j, k)
    coord = [np.arange(ks) for ks in kernel_size]
    xx, yy, zz = np.meshgrid(coord[0], coord[1], coord[2], indexing="ij")
    xyz_grid = np.stack([xx, yy, zz], axis=0)

    # the centre of each axis and the per-axis variance, shaped (3, 1, 1, 1)
    # to broadcast against xyz_grid
    mean = np.asarray([(ks - 1) / 2.0 for ks in kernel_size]).reshape(-1, 1, 1, 1)
    variance = np.asarray([ks ** 2.0 for ks in kernel_sigma]).reshape(-1, 1, 1, 1)

    # The 3D gaussian is the product of three independent 1D gaussians, i.e.
    # exp of the sum of the per-axis exponents. The analytical constant
    # 1 / ((2 * pi) ** 1.5 * prod(sigma)) is omitted because the kernel is
    # renormalised so that its values sum to 1.
    kernel = np.exp(-np.sum((xyz_grid - mean) ** 2.0 / (2 * variance), axis=0))
    kernel = kernel / np.sum(kernel)

    # place the same gaussian on the diagonal of the trailing (3, 3) axes
    total_kernel = np.zeros(tuple(kernel_size) + (3, 3))
    for channel in range(3):
        total_kernel[..., channel, channel] = kernel

    return tf.convert_to_tensor(total_kernel, dtype=tf.float32)
516
|
|
|
|