Completed
Push — main ( 0c57ec...f6b5bf )
by Yunguan
18s queued 13s
created

deepreg.model.layer_util.deconv_output_padding()   B

Complexity

Conditions 6

Size

Total Lines 40
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 26
dl 0
loc 40
rs 8.3226
c 0
b 0
f 0
cc 6
nop 5
"""
Module containing utilities for layer inputs
"""
4
import itertools
5
from typing import List, Tuple, Union
6
7
import numpy as np
8
import tensorflow as tf
9
10
11
def get_reference_grid(grid_size: Union[Tuple[int, ...], List[int]]) -> tf.Tensor:
    """
    Generate a 3D grid with given size.

    Reference:

    - volshape_to_meshgrid of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      neuron modifies meshgrid to make it faster, however local
      benchmark suggests tf.meshgrid is better

    Note:

    for tf.meshgrid, in the 3-D case with inputs of length M, N and P,
    outputs are of shape (N, M, P) for 'xy' indexing and
    (M, N, P) for 'ij' indexing.

    :param grid_size: list or tuple of size 3, [dim1, dim2, dim3],
        only the first three entries are read
    :return: float32 tensor of shape = (dim1, dim2, dim3, 3),
             grid[i, j, k, :] = [i j k]
    """

    # dim1, dim2, dim3 = grid_size
    # mesh_grid has three elements, corresponding to i, j, k
    # for i in range(dim1)
    #     for j in range(dim2)
    #         for k in range(dim3)
    #             mesh_grid[0][i,j,k] = i
    #             mesh_grid[1][i,j,k] = j
    #             mesh_grid[2][i,j,k] = k
    mesh_grid = tf.meshgrid(
        tf.range(grid_size[0]),
        tf.range(grid_size[1]),
        tf.range(grid_size[2]),
        indexing="ij",
    )  # has three elements, each shape = (dim1, dim2, dim3)
    grid = tf.stack(mesh_grid, axis=3)  # shape = (dim1, dim2, dim3, 3)
    grid = tf.cast(grid, dtype=tf.float32)
    return grid
51
52
53
def get_n_bits_combinations(num_bits: int) -> List[List[int]]:
    """
    Function returning list containing all combinations of n bits.
    Given num_bits binary bits, each bit has value 0 or 1,
    there are in total 2**num_bits combinations.

    :param num_bits: int, number of bits, must be >= 1
    :return: a list of length 2**num_bits,
      return[i] is the binary representation of the decimal integer i,
      most significant bit first.
    :raises ValueError: if num_bits < 1

    :Example:
        >>> from deepreg.model.layer_util import get_n_bits_combinations
        >>> get_n_bits_combinations(3)
        [[0, 0, 0], # 0
         [0, 0, 1], # 1
         [0, 1, 0], # 2
         [0, 1, 1], # 3
         [1, 0, 0], # 4
         [1, 0, 1], # 5
         [1, 1, 0], # 6
         [1, 1, 1]] # 7
    """
    # explicit check instead of assert, so that it is not stripped under -O
    if num_bits < 1:
        raise ValueError(
            f"num_bits must be a positive integer, but got num_bits = {num_bits}"
        )
    # itertools.product iterates the last position fastest,
    # which yields the rows in increasing binary order
    return [list(bits) for bits in itertools.product([0, 1], repeat=num_bits)]
77
78
79
def pyramid_combination(
    values: list, weight_floor: list, weight_ceil: list
) -> tf.Tensor:
    r"""
    Calculates linear interpolation (a weighted sum) using values of
    hypercube corners in dimension n.

    For example, when num_dimension = len(loc_shape) = num_bits = 3
    values correspond to values at corners of following coordinates

    .. code-block:: python

        [[0, 0, 0], # even
         [0, 0, 1], # odd
         [0, 1, 0], # even
         [0, 1, 1], # odd
         [1, 0, 0], # even
         [1, 0, 1], # odd
         [1, 1, 0], # even
         [1, 1, 1]] # odd

    values[::2] correspond to the corners with last coordinate == 0

    .. code-block:: python

        [[0, 0, 0],
         [0, 1, 0],
         [1, 0, 0],
         [1, 1, 0]]

    values[1::2] correspond to the corners with last coordinate == 1

    .. code-block:: python

        [[0, 0, 1],
         [0, 1, 1],
         [1, 0, 1],
         [1, 1, 1]]

    The weights correspond to the floor corners.
    For example, when num_dimension = len(loc_shape) = num_bits = 3,
    weight_floor = [f1, f2, f3] (ignoring the batch dimension).
    weight_ceil = [c1, c2, c3] (ignoring the batch dimension).

    So for corner with coords (x, y, z), x, y, z's values are 0 or 1

    - weight for x = f1 if x = 0 else c1
    - weight for y = f2 if y = 0 else c2
    - weight for z = f3 if z = 0 else c3

    so the weight for (x, y, z) is

    .. code-block:: text

        W_xyz = ((1-x) * f1 + x * c1)
              * ((1-y) * f2 + y * c2)
              * ((1-z) * f3 + z * c3)

    Let

    .. code-block:: text

        W_xy = ((1-x) * f1 + x * c1)
             * ((1-y) * f2 + y * c2)

    Then

    - W_xy0 = W_xy * f3
    - W_xy1 = W_xy * c3

    Similar to W_xyz, denote V_xyz the value at (x, y, z),
    the final sum V equals

    .. code-block:: text

          sum over x,y,z (V_xyz * W_xyz)
        = sum over x,y (V_xy0 * W_xy0 + V_xy1 * W_xy1)
        = sum over x,y (V_xy0 * W_xy * f3 + V_xy1 * W_xy * c3)
        = sum over x,y (V_xy0 * W_xy) * f3 + sum over x,y (V_xy1 * W_xy) * c3

    That's why we call this pyramid combination.
    It calculates the linear interpolation gradually, starting from
    the last dimension.
    The key is that the weight of each corner is the product of the weights
    along each dimension.

    :param values: a list having values on the corner,
                   it has 2**n tensors of shape
                   (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
                   the order is consistent with get_n_bits_combinations
                   loc_shape is independent from n, aka num_dim
    :param weight_floor: a list having weights of floor points,
                    it has n tensors of shape
                    (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :param weight_ceil: a list having weights of ceil points,
                    it has n tensors of shape
                    (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :return: one tensor of the same shape as an element in values
             (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if weight_floor and weight_ceil have different lengths,
                        if the elements of values, weight_floor and weight_ceil
                        do not have the same number of dimensions,
                        or if len(values) != 2 ** len(weight_floor)
    """
    # the two weight lists are indexed in lockstep below,
    # a length mismatch would pair weights of different dimensions
    if len(weight_floor) != len(weight_ceil):
        raise ValueError(
            "In pyramid_combination, "
            "weight_floor and weight_ceil must have the same length, "
            f"But len(weight_floor) = {len(weight_floor)}, "
            f"len(weight_ceil) = {len(weight_ceil)}"
        )
    v_shape = values[0].shape
    wf_shape = weight_floor[0].shape
    wc_shape = weight_ceil[0].shape
    if len(v_shape) != len(wf_shape) or len(v_shape) != len(wc_shape):
        raise ValueError(
            "In pyramid_combination, elements of "
            "values, weight_floor, and weight_ceil should have same dimension. "
            f"value shape = {v_shape}, "
            f"weight_floor = {wf_shape}, "
            f"weight_ceil = {wc_shape}."
        )
    if 2 ** len(weight_floor) != len(values):
        raise ValueError(
            "In pyramid_combination, "
            "num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim, "
            f"But len(weight_floor) = {len(weight_floor)}, "
            f"len(values) = {len(values)}"
        )

    if len(weight_floor) == 1:  # one dimension
        return values[0] * weight_floor[0] + values[1] * weight_ceil[0]
    # multi dimension
    # even-indexed corners have last coordinate 0, i.e. floor on last dimension
    values_floor = pyramid_combination(
        values=values[::2],
        weight_floor=weight_floor[:-1],
        weight_ceil=weight_ceil[:-1],
    )
    values_floor = values_floor * weight_floor[-1]
    # odd-indexed corners have last coordinate 1, i.e. ceil on last dimension
    values_ceil = pyramid_combination(
        values=values[1::2],
        weight_floor=weight_floor[:-1],
        weight_ceil=weight_ceil[:-1],
    )
    values_ceil = values_ceil * weight_ceil[-1]
    return values_floor + values_ceil
215
216
217
def resample(
    vol: tf.Tensor,
    loc: tf.Tensor,
    interpolation: str = "linear",
    zero_boundary: bool = True,
) -> tf.Tensor:
    r"""
    Sample the volume at given locations.

    Input has

    - volume, vol, of shape = (batch, v_dim 1, ..., v_dim n),
      or (batch, v_dim 1, ..., v_dim n, ch),
      where n is the dimension of volume,
      ch is the extra dimension as features.

      Denote vol_shape = (v_dim 1, ..., v_dim n)

    - location, loc, of shape = (batch, l_dim 1, ..., l_dim m, n),
      where m is the dimension of output.

      Denote loc_shape = (l_dim 1, ..., l_dim m)

    Reference:

    - neuron's interpn
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      Difference

      1. they dont have batch size
      2. they support more dimensions in vol

      TODO try not using stack as neuron claims it's slower

    :param vol: shape = (batch, \*vol_shape) or (batch, \*vol_shape, ch)
      with the last channel for features
    :param loc: shape = (batch, \*loc_shape, n)
      such that loc[b, l1, ..., lm, :] = [v1, ..., vn] is of shape (n,),
      which represents a point in vol, with coordinates (v1, ..., vn)
    :param interpolation: linear only, TODO support nearest
    :param zero_boundary: if true, values on or outside boundary will be zeros
    :return: shape = (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if interpolation is not "linear",
      or if the rank of vol is inconsistent with the last dimension of loc
    """

    if interpolation != "linear":
        raise ValueError("resample supports only linear interpolation")

    # init
    batch_size = vol.shape[0]
    loc_shape = loc.shape[1:-1]
    dim_vol = loc.shape[-1]  # dimension of vol, n
    if dim_vol == len(vol.shape) - 1:
        # vol.shape = (batch, *vol_shape)
        has_ch = False
    elif dim_vol == len(vol.shape) - 2:
        # vol.shape = (batch, *vol_shape, ch)
        has_ch = True
    else:
        raise ValueError(
            "vol shape inconsistent with loc "
            f"vol.shape = {vol.shape}, loc.shape = {loc.shape}"
        )
    vol_shape = vol.shape[1 : dim_vol + 1]

    # get floor/ceil for loc and stack, then clip together
    # loc, loc_floor, loc_ceil all have shape (batch, *loc_shape, n)
    # floor is defined as ceil - 1, so for integer coordinates
    # floor = loc - 1 and ceil = loc
    loc_ceil = tf.math.ceil(loc)
    loc_floor = loc_ceil - 1
    # (batch, *loc_shape, n, 3)
    clipped = tf.stack([loc, loc_floor, loc_ceil], axis=-1)
    clip_value_max = tf.cast(vol_shape, dtype=clipped.dtype) - 1  # (n,)
    # reshape so that the per-dimension maximum broadcasts against clipped
    clipped_shape = [1] * (len(loc_shape) + 1) + [dim_vol, 1]
    clip_value_max = tf.reshape(clip_value_max, shape=clipped_shape)
    clipped = tf.clip_by_value(clipped, clip_value_min=0, clip_value_max=clip_value_max)

    # loc_floor_ceil has n sublists
    # each one corresponds to the floor and ceil coordinates for d-th dimension
    # each tensor is of shape (batch, *loc_shape), dtype int32

    # weight_floor has n tensors
    # each tensor is the weight for the corner of floor coordinates
    # each tensor's shape is (batch, *loc_shape) if volume has no feature channel
    #                        (batch, *loc_shape, 1) if volume has feature channel
    loc_floor_ceil, weight_floor, weight_ceil = [], [], []
    # using for loop is faster than using list comprehension
    for dim in range(dim_vol):
        # shape = (batch, *loc_shape)
        c_clipped = clipped[..., dim, 0]
        c_floor = clipped[..., dim, 1]
        c_ceil = clipped[..., dim, 2]
        w_floor = c_ceil - c_clipped  # shape = (batch, *loc_shape)
        # with zero_boundary, floor == ceil after clipping gives both weights 0
        # otherwise the two weights always sum to 1
        w_ceil = c_clipped - c_floor if zero_boundary else 1 - w_floor
        if has_ch:
            w_floor = tf.expand_dims(w_floor, -1)  # shape = (batch, *loc_shape, 1)
            w_ceil = tf.expand_dims(w_ceil, -1)  # shape = (batch, *loc_shape, 1)

        loc_floor_ceil.append([tf.cast(c_floor, tf.int32), tf.cast(c_ceil, tf.int32)])
        weight_floor.append(w_floor)
        weight_ceil.append(w_ceil)

    # 2**n corners, each is a list of n binary values
    corner_indices = get_n_bits_combinations(num_bits=len(vol_shape))

    # batch_coords[b, l1, ..., lm] = b
    # range(batch_size) on axis 0 and repeated on other axes
    # add batch coords manually is faster than using batch_dims in tf.gather_nd
    batch_coords = tf.tile(
        tf.reshape(tf.range(batch_size), [batch_size] + [1] * len(loc_shape)),
        [1] + loc_shape,
    )  # shape = (batch, *loc_shape)

    # get vol values on n-dim hypercube corners
    # corner_values has 2 ** n elements
    # each of shape (batch, *loc_shape) or (batch, *loc_shape, ch)
    corner_values = [
        tf.gather_nd(
            vol,  # shape = (batch, *vol_shape) or (batch, *vol_shape, ch)
            tf.stack(
                [batch_coords]
                + [loc_floor_ceil[axis][fc_idx] for axis, fc_idx in enumerate(c)],
                axis=-1,
            ),  # shape = (batch, *loc_shape, n+1) after stack
        )
        for c in corner_indices  # c is list of len n
    ]  # each tensor has shape (batch, *loc_shape) or (batch, *loc_shape, ch)

    # resample
    sampled = pyramid_combination(
        values=corner_values, weight_floor=weight_floor, weight_ceil=weight_ceil
    )
    return sampled
349
350
351
def warp_grid(grid: tf.Tensor, theta: tf.Tensor) -> tf.Tensor:
    """
    Perform transformation on the grid.

    - grid_padded[i,j,k,:] = [i j k 1]
    - grid_warped[b,i,j,k,p] = sum_over_q (grid_padded[i,j,k,q] * theta[b,q,p])

    :param grid: shape = (dim1, dim2, dim3, 3), grid[i,j,k,:] = [i j k]
    :param theta: parameters of transformation, shape = (batch, 4, 3)
    :return: shape = (batch, dim1, dim2, dim3, 3)
    """

    # append a constant 1 to each coordinate so that the last row of theta
    # is added to every point
    # grid_padded[i,j,k,:] = [i j k 1], shape = (dim1, dim2, dim3, 4)
    grid_padded = tf.concat([grid, tf.ones_like(grid[..., :1])], axis=3)

    # grid_warped[b,i,j,k,p] = sum_over_q (grid_padded[i,j,k,q] * theta[b,q,p])
    # shape = (batch, dim1, dim2, dim3, 3)
    grid_warped = tf.einsum("ijkq,bqp->bijkp", grid_padded, theta)
    return grid_warped
370
371
372
def gaussian_filter_3d(kernel_sigma: Union[int, float, Tuple, List]) -> tf.Tensor:
    """
    Define a gaussian filter in 3d for smoothing.

    Along each axis the filter size is ceil(3 * sigma),
    increased by one if that number is even, so that the size is always odd
    and the kernel has a central voxel.

    :param kernel_sigma: the deviation at each direction (list or tuple of 3)
        or use an isotropic deviation (int or float)
    :return: kernel: float32 tf.Tensor of shape (k1, k2, k3, 3, 3),
        where (k1, k2, k3) are the filter sizes described above.
        kernel[..., c, c] is the gaussian kernel, normalised to sum to 1,
        for c = 0, 1, 2, and all other entries are zeros.
    """
    if isinstance(kernel_sigma, (int, float)):
        kernel_sigma = (kernel_sigma, kernel_sigma, kernel_sigma)

    # ceil(3 * sigma), plus one when even, to get an odd size
    kernel_size = [
        int(np.ceil(ks * 3) + np.mod(np.ceil(ks * 3) + 1, 2)) for ks in kernel_sigma
    ]

    # Create a x, y, z coordinate grid of shape (3, k1, k2, k3)
    coord = [np.arange(ks) for ks in kernel_size]

    xx, yy, zz = np.meshgrid(coord[0], coord[1], coord[2], indexing="ij")
    xyz_grid = np.concatenate(
        (xx[np.newaxis], yy[np.newaxis], zz[np.newaxis]), axis=0
    )  # shape = (3, k1, k2, k3)

    # the gaussian is centred on the middle voxel of the kernel
    mean = np.asarray([(ks - 1) / 2.0 for ks in kernel_size])
    mean = mean.reshape(-1, 1, 1, 1)
    variance = np.asarray([ks ** 2.0 for ks in kernel_sigma])
    variance = variance.reshape(-1, 1, 1, 1)

    # Calculate the 3-dimensional gaussian kernel which is
    # the product of three gaussian distributions, one per axis.
    # The normalising constant is 1 / ((2 * pi) ** (3 / 2) * s1 * s2 * s3),
    # i.e. a product of the per-axis constants, not a sum.
    norm_kernel = 1.0 / (np.sqrt(2 * np.pi) ** 3 * np.prod(kernel_sigma))
    kernel = norm_kernel * np.exp(
        -np.sum((xyz_grid - mean) ** 2.0 / (2 * variance), axis=0)
    )  # shape = (k1, k2, k3)

    # Make sure sum of values in gaussian kernel equals 1.
    # The kernel is truncated, so the analytical constant alone is not enough.
    kernel = kernel / np.sum(kernel)

    # Total kernel, the same gaussian on the diagonal of the last two axes
    # so that each of the three components is smoothed independently
    total_kernel = np.zeros(tuple(kernel_size) + (3, 3))
    total_kernel[..., 0, 0] = kernel
    total_kernel[..., 1, 1] = kernel
    total_kernel[..., 2, 2] = kernel

    return tf.convert_to_tensor(total_kernel, dtype=tf.float32)
427
428
429
def _deconv_output_padding(
0 ignored issues
show
introduced by
"ValueError" not documented as being raised
Loading history...
430
    input_shape: int, output_shape: int, kernel_size: int, stride: int, padding: str
431
) -> int:
432
    """
433
    Calculate output padding for Conv3DTranspose in 1D.
434
435
    - output_shape = (input_shape - 1)*stride + kernel_size - 2*pad + output_padding
436
    - output_padding = output_shape - ((input_shape - 1)*stride + kernel_size - 2*pad)
437
438
    Reference:
439
440
    - https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/python/keras/utils/conv_utils.py#L140
441
442
    :param input_shape: shape of Conv3DTranspose input tensor
443
    :param output_shape: shape of Conv3DTranspose output tensor
444
    :param kernel_size: kernel size of Conv3DTranspose layer
445
    :param stride: stride of Conv3DTranspose layer
446
    :param padding: padding of Conv3DTranspose layer
447
    :return: output_padding for Conv3DTranspose layer
448
    """
449
    if padding == "same":
450
        pad = kernel_size // 2
451
    elif padding == "valid":
452
        pad = 0
453
    elif padding == "full":
454
        pad = kernel_size - 1
455
    else:
456
        raise ValueError(f"Unknown padding {padding} in deconv_output_padding")
457
    return output_shape - ((input_shape - 1) * stride + kernel_size - 2 * pad)
458
459
460
def deconv_output_padding(
    input_shape: Union[Tuple[int, ...], int],
    output_shape: Union[Tuple[int, ...], int],
    kernel_size: Union[Tuple[int, ...], int],
    stride: Union[Tuple[int, ...], int],
    padding: str,
) -> Union[Tuple[int, ...], int]:
    """
    Calculate output padding for Conv3DTranspose in any dimension.

    The number of dimensions is taken from input_shape.
    An int input_shape or output_shape is treated as a 1D shape,
    an int kernel_size or stride is repeated for every dimension.

    :param input_shape: shape of Conv3DTranspose input tensor, without batch or channel
    :param output_shape: shape of Conv3DTranspose output tensor,
        without batch or channel
    :param kernel_size: kernel size of Conv3DTranspose layer
    :param stride: stride of Conv3DTranspose layer
    :param padding: padding of Conv3DTranspose layer
    :return: output_padding for Conv3DTranspose layer,
        an int if there is only one dimension, otherwise a tuple of ints
    :raises ValueError: if padding is not supported by _deconv_output_padding
    """
    if isinstance(input_shape, int):
        input_shape = (input_shape,)
    dim = len(input_shape)
    if isinstance(output_shape, int):
        output_shape = (output_shape,)
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * dim
    if isinstance(stride, int):
        stride = (stride,) * dim
    # the padding is calculated independently for each dimension
    output_padding = tuple(
        _deconv_output_padding(
            input_shape=input_shape[d],
            output_shape=output_shape[d],
            kernel_size=kernel_size[d],
            stride=stride[d],
            padding=padding,
        )
        for d in range(dim)
    )
    if dim == 1:
        return output_padding[0]
    return output_padding
500