1
|
|
|
""" |
2
|
|
|
Module containing utilities for layer inputs |
3
|
|
|
""" |
4
|
|
|
import itertools |
5
|
|
|
from typing import List, Tuple, Union |
6
|
|
|
|
7
|
|
|
import numpy as np |
8
|
|
|
import tensorflow as tf |
9
|
|
|
|
10
|
|
|
|
11
|
|
|
def get_reference_grid(grid_size: Union[Tuple[int, ...], List[int]]) -> tf.Tensor:
    """
    Build a 3D grid whose entries are their own voxel indices.

    The grid is assembled with tf.meshgrid using "ij" indexing, i.e.
    for inputs of length M, N and P each mesh has shape (M, N, P)
    (with "xy" indexing it would be (N, M, P)), so that

    - mesh[0][i, j, k] = i
    - mesh[1][i, j, k] = j
    - mesh[2][i, j, k] = k

    Reference:

    - volshape_to_meshgrid of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      neuron modifies meshgrid to make it faster, however local
      benchmark suggests tf.meshgrid is better

    :param grid_size: list or tuple of size 3, [dim1, dim2, dim3]
    :return: float32 tensor of shape (dim1, dim2, dim3, 3),
        grid[i, j, k, :] = [i j k]
    """
    # one index range per spatial axis
    axis_ranges = [tf.range(grid_size[axis]) for axis in range(3)]
    # three tensors, each of shape (dim1, dim2, dim3)
    per_axis_index = tf.meshgrid(*axis_ranges, indexing="ij")
    # stack on a new trailing axis, shape = (dim1, dim2, dim3, 3)
    return tf.cast(tf.stack(per_axis_index, axis=3), dtype=tf.float32)
51
|
|
|
|
52
|
|
|
|
53
|
|
|
def get_n_bits_combinations(num_bits: int) -> List[List[int]]:
    """
    Enumerate every combination of num_bits binary digits.

    Each bit is either 0 or 1, so there are 2**num_bits combinations.
    They are listed in increasing order of the integer they encode,
    the first bit being the most significant one.

    :param num_bits: int, number of bits in each combination, at least 1
    :return: a list of length 2**num_bits,
        return[i] is the binary representation of the decimal integer i.

    :Example:
        >>> from deepreg.model.layer_util import get_n_bits_combinations
        >>> get_n_bits_combinations(3)
        [[0, 0, 0], # 0
         [0, 0, 1], # 1
         [0, 1, 0], # 2
         [0, 1, 1], # 3
         [1, 0, 0], # 4
         [1, 0, 1], # 5
         [1, 1, 0], # 6
         [1, 1, 1]] # 7
    """
    assert num_bits >= 1
    # itertools.product yields tuples in lexicographic order, convert to lists
    bit_tuples = itertools.product((0, 1), repeat=num_bits)
    return list(map(list, bit_tuples))
77
|
|
|
|
78
|
|
|
|
79
|
|
|
def pyramid_combination(
    values: list, weight_floor: list, weight_ceil: list
) -> tf.Tensor:
    r"""
    Calculates linear interpolation (a weighted sum) using values of
    hypercube corners in dimension n.

    For example, when num_dimension = len(loc_shape) = num_bits = 3
    values correspond to values at corners of following coordinates

    .. code-block:: python

        [[0, 0, 0], # even
         [0, 0, 1], # odd
         [0, 1, 0], # even
         [0, 1, 1], # odd
         [1, 0, 0], # even
         [1, 0, 1], # odd
         [1, 1, 0], # even
         [1, 1, 1]] # odd

    values[::2] correspond to the corners with last coordinate == 0

    .. code-block:: python

        [[0, 0, 0],
         [0, 1, 0],
         [1, 0, 0],
         [1, 1, 0]]

    values[1::2] correspond to the corners with last coordinate == 1

    .. code-block:: python

        [[0, 0, 1],
         [0, 1, 1],
         [1, 0, 1],
         [1, 1, 1]]

    The weights correspond to the floor corners.
    For example, when num_dimension = len(loc_shape) = num_bits = 3,
    weight_floor = [f1, f2, f3] (ignoring the batch dimension).
    weight_ceil = [c1, c2, c3] (ignoring the batch dimension).

    So for corner with coords (x, y, z), x, y, z's values are 0 or 1

    - weight for x = f1 if x = 0 else c1
    - weight for y = f2 if y = 0 else c2
    - weight for z = f3 if z = 0 else c3

    so the weight for (x, y, z) is

    .. code-block:: text

        W_xyz = ((1-x) * f1 + x * c1)
              * ((1-y) * f2 + y * c2)
              * ((1-z) * f3 + z * c3)

    Let

    .. code-block:: text

        W_xy = ((1-x) * f1 + x * c1)
             * ((1-y) * f2 + y * c2)

    Then

    - W_xy0 = W_xy * f3
    - W_xy1 = W_xy * c3

    Similar to W_xyz, denote V_xyz the value at (x, y, z),
    the final sum V equals

    .. code-block:: text

          sum over x,y,z (V_xyz * W_xyz)
        = sum over x,y (V_xy0 * W_xy0 + V_xy1 * W_xy1)
        = sum over x,y (V_xy0 * W_xy * f3 + V_xy1 * W_xy * c3)
        = sum over x,y (V_xy0 * W_xy) * f3 + sum over x,y (V_xy1 * W_xy) * c3

    That's why we call this pyramid combination.
    It calculates the linear interpolation gradually, starting from
    the last dimension.
    The key is that the weight of each corner is the product of the weights
    along each dimension.

    :param values: a list having values on the corner,
        it has 2**n tensors of shape
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
        the order is consistent with get_n_bits_combinations
        loc_shape is independent from n, aka num_dim
    :param weight_floor: a list having weights of floor points,
        it has n tensors of shape
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :param weight_ceil: a list having weights of ceil points,
        it has n tensors of shape
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :return: one tensor of the same shape as an element in values
        (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if weight_floor and weight_ceil have different
        lengths, if the tensors do not have the same number of dimensions,
        or if len(values) != 2 ** len(weight_floor)
    """
    # Both weight lists are indexed per dimension in lock-step below
    # (weight_floor[-1] / weight_ceil[-1]); a length mismatch would pair
    # weights of different dimensions and give a wrong result silently.
    if len(weight_floor) != len(weight_ceil):
        raise ValueError(
            "In pyramid_combination, "
            "weight_floor and weight_ceil must have the same length, "
            f"But len(weight_floor) = {len(weight_floor)}, "
            f"len(weight_ceil) = {len(weight_ceil)}"
        )

    v_shape = values[0].shape
    wf_shape = weight_floor[0].shape
    wc_shape = weight_ceil[0].shape
    if len(v_shape) != len(wf_shape) or len(v_shape) != len(wc_shape):
        raise ValueError(
            "In pyramid_combination, elements of "
            "values, weight_floor, and weight_ceil should have same dimension. "
            f"value shape = {v_shape}, "
            f"weight_floor = {wf_shape}, "
            f"weight_ceil = {wc_shape}."
        )
    if 2 ** len(weight_floor) != len(values):
        raise ValueError(
            "In pyramid_combination, "
            "num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim, "
            f"But len(weight_floor) = {len(weight_floor)}, "
            f"len(values) = {len(values)}"
        )

    if len(weight_floor) == 1:  # one dimension
        return values[0] * weight_floor[0] + values[1] * weight_ceil[0]

    # multi dimension:
    # interpolate over the first n-1 dimensions separately for the corners
    # whose last coordinate is 0 (even indices) and 1 (odd indices),
    # then blend the two results with the weights of the last dimension
    values_floor = pyramid_combination(
        values=values[::2],
        weight_floor=weight_floor[:-1],
        weight_ceil=weight_ceil[:-1],
    )
    values_floor = values_floor * weight_floor[-1]
    values_ceil = pyramid_combination(
        values=values[1::2],
        weight_floor=weight_floor[:-1],
        weight_ceil=weight_ceil[:-1],
    )
    values_ceil = values_ceil * weight_ceil[-1]
    return values_floor + values_ceil
215
|
|
|
|
216
|
|
|
|
217
|
|
|
def resample(
    vol: tf.Tensor,
    loc: tf.Tensor,
    interpolation: str = "linear",
    zero_boundary: bool = True,
) -> tf.Tensor:
    r"""
    Sample the volume at given locations using linear interpolation.

    Input has

    - volume, vol, of shape = (batch, v_dim 1, ..., v_dim n),
      or (batch, v_dim 1, ..., v_dim n, ch),
      where n is the dimension of volume,
      ch is the extra dimension as features.

      Denote vol_shape = (v_dim 1, ..., v_dim n)

    - location, loc, of shape = (batch, l_dim 1, ..., l_dim m, n),
      where m is the dimension of output.

      Denote loc_shape = (l_dim 1, ..., l_dim m)

    For each location, the values of vol at the 2**n corners of the
    surrounding unit hypercube (coordinates clipped into the volume)
    are gathered and combined with pyramid_combination.

    Reference:

    - neuron's interpn
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      Difference

      1. they dont have batch size
      2. they support more dimensions in vol

      TODO try not using stack as neuron claims it's slower

    :param vol: shape = (batch, \*vol_shape) or (batch, \*vol_shape, ch)
        with the last channel for features
    :param loc: shape = (batch, \*loc_shape, n)
        such that loc[b, l1, ..., lm, :] = [v1, ..., vn] is of shape (n,),
        which represents a point in vol, with coordinates (v1, ..., vn)
    :param interpolation: linear only, TODO support nearest
    :param zero_boundary: if true, the ceil weight is taken as the distance
        between the clipped location and the clipped floor, so a location
        whose clipped floor and ceil coincide gets zero weights and is
        sampled as zero; if false, the ceil weight is 1 - floor weight
    :return: shape = (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if interpolation is not "linear" or if the rank
        of vol is inconsistent with the last dimension of loc
    """
    if interpolation != "linear":
        raise ValueError("resample supports only linear interpolation")

    # static shapes
    batch_size = vol.shape[0]
    loc_shape = loc.shape[1:-1]
    num_dims = loc.shape[-1]  # dimension of vol, n
    vol_rank = len(vol.shape)
    if num_dims == vol_rank - 1:
        # vol.shape = (batch, *vol_shape)
        has_ch = False
    elif num_dims == vol_rank - 2:
        # vol.shape = (batch, *vol_shape, ch)
        has_ch = True
    else:
        raise ValueError(
            "vol shape inconsistent with loc "
            f"vol.shape = {vol.shape}, loc.shape = {loc.shape}"
        )
    vol_shape = vol.shape[1 : num_dims + 1]

    # floor is derived from ceil, then location / floor / ceil are
    # stacked so that a single clip keeps all of them inside the volume
    # loc, upper and lower all have shape (batch, *loc_shape, n)
    upper = tf.math.ceil(loc)
    lower = upper - 1
    stacked = tf.stack([loc, lower, upper], axis=-1)  # (batch, *loc_shape, n, 3)
    max_index = tf.cast(vol_shape, dtype=stacked.dtype) - 1  # (n,)
    # reshape so that max_index broadcasts against stacked
    broadcast_shape = [1] * (len(loc_shape) + 1) + [num_dims, 1]
    max_index = tf.reshape(max_index, shape=broadcast_shape)
    stacked = tf.clip_by_value(stacked, clip_value_min=0, clip_value_max=max_index)

    # corner_coords has n entries, one per dimension, each holding the
    # int32 floor and ceil coordinates of shape (batch, *loc_shape)
    #
    # weight_floor / weight_ceil have n tensors each, of shape
    # (batch, *loc_shape) if volume has no feature channel,
    # (batch, *loc_shape, 1) if volume has feature channel
    corner_coords, weight_floor, weight_ceil = [], [], []
    # using for loop is faster than using list comprehension
    for axis in range(num_dims):
        # each of shape (batch, *loc_shape)
        pos = stacked[..., axis, 0]
        pos_floor = stacked[..., axis, 1]
        pos_ceil = stacked[..., axis, 2]

        w_floor = pos_ceil - pos
        if zero_boundary:
            w_ceil = pos - pos_floor
        else:
            w_ceil = 1 - w_floor
        if has_ch:
            # add a trailing axis to broadcast over channels
            w_floor = tf.expand_dims(w_floor, -1)
            w_ceil = tf.expand_dims(w_ceil, -1)

        corner_coords.append(
            [tf.cast(pos_floor, tf.int32), tf.cast(pos_ceil, tf.int32)]
        )
        weight_floor.append(w_floor)
        weight_ceil.append(w_ceil)

    # 2**n corners, each is a list of n binary values (0 = floor, 1 = ceil)
    corner_bits = get_n_bits_combinations(num_bits=len(vol_shape))

    # batch_coords[b, l1, ..., lm] = b
    # range(batch_size) on axis 0 and repeated on other axes
    # add batch coords manually is faster than using batch_dims in tf.gather_nd
    batch_range = tf.reshape(
        tf.range(batch_size), [batch_size] + [1] * len(loc_shape)
    )
    batch_coords = tf.tile(batch_range, [1] + loc_shape)  # (batch, *loc_shape)

    # get vol values on n-dim hypercube corners
    # corner_values has 2 ** n elements
    # each of shape (batch, *loc_shape) or (batch, *loc_shape, ch)
    corner_values = []
    for bits in corner_bits:  # bits is list of len n
        per_axis = [corner_coords[axis][bit] for axis, bit in enumerate(bits)]
        # shape = (batch, *loc_shape, n+1) after stack
        gather_index = tf.stack([batch_coords] + per_axis, axis=-1)
        corner_values.append(tf.gather_nd(vol, gather_index))

    # weighted sum of the corner values
    return pyramid_combination(
        values=corner_values, weight_floor=weight_floor, weight_ceil=weight_ceil
    )
349
|
|
|
|
350
|
|
|
|
351
|
|
|
def warp_grid(grid: tf.Tensor, theta: tf.Tensor) -> tf.Tensor:
    """
    Apply a batch of transformations to the grid.

    The grid coordinates are extended with a constant 1 and
    multiplied with theta:

    - grid_padded[i,j,k,:] = [i j k 1]
    - grid_warped[b,i,j,k,p] = sum_over_q (grid_padded[i,j,k,q] * theta[b,q,p])

    :param grid: shape = (dim1, dim2, dim3, 3), grid[i,j,k,:] = [i j k]
    :param theta: parameters of transformation, shape = (batch, 4, 3)
    :return: shape = (batch, dim1, dim2, dim3, 3)
    """
    # column of ones, shape = (dim1, dim2, dim3, 1)
    ones_column = tf.ones_like(grid[..., :1])
    # [i j k 1] at every voxel, shape = (dim1, dim2, dim3, 4)
    homogeneous_grid = tf.concat([grid, ones_column], axis=3)
    # contract the padded coordinate axis q with theta,
    # shape = (batch, dim1, dim2, dim3, 3)
    return tf.einsum("ijkq,bqp->bijkp", homogeneous_grid, theta)
370
|
|
|
|
371
|
|
|
|
372
|
|
|
def gaussian_filter_3d(kernel_sigma: Union[Tuple, List, int, float]) -> tf.Tensor:
    """
    Define a gaussian filter in 3d for smoothing.

    Along each axis the kernel size is ceil(3 * sigma), increased by one
    when that number is even, so the size is always odd.

    :param kernel_sigma: the deviation at each direction (list or tuple
        of three values) or a single number used as an isotropic deviation
    :return: kernel: tf.Tensor of dtype float32 and
        shape (k1, k2, k3, 3, 3), where (k1, k2, k3) are the kernel sizes.
        kernel[..., c, c] holds the gaussian kernel (summing to one)
        for c = 0, 1, 2 and all other entries are zero.
    """
    if isinstance(kernel_sigma, (int, float)):
        kernel_sigma = (kernel_sigma, kernel_sigma, kernel_sigma)

    # ceil(3 * sigma), plus one if even so that the size is odd
    kernel_size = [
        int(np.ceil(ks * 3) + np.mod(np.ceil(ks * 3) + 1, 2)) for ks in kernel_sigma
    ]

    # coordinate grid over the kernel, shape = (3, k1, k2, k3)
    coord = [np.arange(ks) for ks in kernel_size]
    xx, yy, zz = np.meshgrid(coord[0], coord[1], coord[2], indexing="ij")
    xyz_grid = np.stack((xx, yy, zz), axis=0)

    # per-axis centre and variance, shaped (3, 1, 1, 1) to broadcast
    # against xyz_grid
    mean = np.asarray([(ks - 1) / 2.0 for ks in kernel_size])
    mean = mean.reshape(-1, 1, 1, 1)
    variance = np.asarray([ks ** 2.0 for ks in kernel_sigma])
    variance = variance.reshape(-1, 1, 1, 1)

    # The 3-dimensional gaussian kernel is the product of three
    # 1-dimensional gaussians, i.e. the exponential of the summed
    # per-axis exponents, shape = (k1, k2, k3).
    # The analytic prefactor 1 / (sqrt(2 * pi) ** 3 * prod(sigma)) is
    # omitted on purpose: the kernel is rescaled to unit sum right
    # below, which cancels any constant factor.
    kernel = np.exp(-np.sum((xyz_grid - mean) ** 2.0 / (2 * variance), axis=0))

    # Make sure sum of values in gaussian kernel equals 1.
    kernel = kernel / np.sum(kernel)

    # Total kernel: the same gaussian on each of the three diagonal
    # entries of the trailing (3, 3) axes, zeros elsewhere
    total_kernel = np.zeros(tuple(kernel_size) + (3, 3))
    for channel in range(3):
        total_kernel[..., channel, channel] = kernel

    return tf.convert_to_tensor(total_kernel, dtype=tf.float32)
427
|
|
|
|
428
|
|
|
|
429
|
|
|
def _deconv_output_padding(
    input_shape: int, output_shape: int, kernel_size: int, stride: int, padding: str
) -> int:
    """
    Calculate output padding for Conv3DTranspose in 1D.

    The relation between the sizes is

    - output_shape = (input_shape - 1)*stride + kernel_size - 2*pad + output_padding

    which is solved for output_padding:

    - output_padding = output_shape - ((input_shape - 1)*stride + kernel_size - 2*pad)

    where pad is kernel_size // 2 for "same", 0 for "valid"
    and kernel_size - 1 for "full".

    Reference:

    - https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/python/keras/utils/conv_utils.py#L140

    :param input_shape: shape of Conv3DTranspose input tensor
    :param output_shape: shape of Conv3DTranspose output tensor
    :param kernel_size: kernel size of Conv3DTranspose layer
    :param stride: stride of Conv3DTranspose layer
    :param padding: padding of Conv3DTranspose layer,
        one of "same", "valid" or "full"
    :return: output_padding for Conv3DTranspose layer
    :raises ValueError: if padding is not one of the supported modes
    """
    # reject unsupported modes first
    if padding not in ("same", "valid", "full"):
        raise ValueError(f"Unknown padding {padding} in deconv_output_padding")

    if padding == "same":
        pad = kernel_size // 2
    else:
        pad = 0 if padding == "valid" else kernel_size - 1

    # size of the transposed convolution output before removing the padding
    unpadded_size = (input_shape - 1) * stride + kernel_size
    return output_shape - (unpadded_size - 2 * pad)
458
|
|
|
|
459
|
|
|
|
460
|
|
|
def deconv_output_padding(
    input_shape: Union[Tuple[int, ...], int],
    output_shape: Union[Tuple[int, ...], int],
    kernel_size: Union[Tuple[int, ...], int],
    stride: Union[Tuple[int, ...], int],
    padding: str,
) -> Union[Tuple[int, ...], int]:
    """
    Calculate output padding for Conv3DTranspose in any dimension.

    Integer arguments are expanded to tuples: the shapes to tuples of
    length one, kernel_size and stride to one entry per dimension of
    input_shape. The padding is then computed independently per dimension.

    :param input_shape: shape of Conv3DTranspose input tensor, without batch or channel
    :param output_shape: shape of Conv3DTranspose output tensor,
        without batch or channel
    :param kernel_size: kernel size of Conv3DTranspose layer
    :param stride: stride of Conv3DTranspose layer
    :param padding: padding of Conv3DTranspose layer
    :return: output_padding for Conv3DTranspose layer,
        an int if there is a single dimension, otherwise a tuple
        with one value per dimension
    """
    # scalar shapes describe a single dimension
    if isinstance(input_shape, int):
        input_shape = (input_shape,)
    if isinstance(output_shape, int):
        output_shape = (output_shape,)

    # scalar kernel size / stride are shared by all dimensions
    num_dims = len(input_shape)
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * num_dims
    if isinstance(stride, int):
        stride = (stride,) * num_dims

    per_axis_padding = tuple(
        _deconv_output_padding(
            input_shape=input_shape[axis],
            output_shape=output_shape[axis],
            kernel_size=kernel_size[axis],
            stride=stride[axis],
            padding=padding,
        )
        for axis in range(num_dims)
    )
    # a single dimension is returned as a plain int
    return per_axis_padding[0] if num_dims == 1 else per_axis_padding
500
|
|
|
|