"""
Module containing utilities for layer inputs.
"""
import itertools
from typing import Optional, Union

import numpy as np
import tensorflow as tf


def get_reference_grid(grid_size: Union[tuple, list]) -> tf.Tensor:
    """
    Generate a 3D grid with given size.

    Reference:

    - volshape_to_meshgrid of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      neuron modifies meshgrid to make it faster, however a local
      benchmark suggests tf.meshgrid is faster

    Note:

    for tf.meshgrid, in the 3-D case with inputs of length M, N and P,
    outputs are of shape (N, M, P) for 'xy' indexing and
    (M, N, P) for 'ij' indexing.

    :param grid_size: list or tuple of size 3, [dim1, dim2, dim3]
    :return: shape = (dim1, dim2, dim3, 3),
             grid[i, j, k, :] = [i j k]
    """
    # mesh_grid has three elements, corresponding to i, j, k
    # for i in range(dim1)
    #     for j in range(dim2)
    #         for k in range(dim3)
    #             mesh_grid[0][i,j,k] = i
    #             mesh_grid[1][i,j,k] = j
    #             mesh_grid[2][i,j,k] = k
    mesh_grid = tf.meshgrid(
        tf.range(grid_size[0]),
        tf.range(grid_size[1]),
        tf.range(grid_size[2]),
        indexing="ij",
    )  # has three elements, each of shape (dim1, dim2, dim3)
    grid = tf.stack(mesh_grid, axis=3)  # shape = (dim1, dim2, dim3, 3)
    grid = tf.cast(grid, dtype=tf.float32)
    return grid


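# Illustrative usage of get_reference_grid (a minimal sketch, assuming an
# eager TensorFlow 2.x session; not part of the module API):
#
#     >>> grid = get_reference_grid(grid_size=(2, 2, 2))
#     >>> grid.shape
#     TensorShape([2, 2, 2, 3])
#     >>> grid[1, 0, 1].numpy()
#     array([1., 0., 1.], dtype=float32)

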
def get_n_bits_combinations(num_bits: int) -> list:
    """
    Function returning a list containing all combinations of n bits.
    Given num_bits binary bits, each bit has value 0 or 1,
    there are in total 2**num_bits combinations.

    :param num_bits: int, number of binary bits
    :return: a list of length 2**num_bits,
      return[i] is the binary representation of the decimal integer i.

    :Example:
        >>> from deepreg.model.layer_util import get_n_bits_combinations
        >>> get_n_bits_combinations(3)
        [[0, 0, 0], # 0
         [0, 0, 1], # 1
         [0, 1, 0], # 2
         [0, 1, 1], # 3
         [1, 0, 0], # 4
         [1, 0, 1], # 5
         [1, 1, 0], # 6
         [1, 1, 1]] # 7
    """
    assert num_bits >= 1
    return [list(i) for i in itertools.product([0, 1], repeat=num_bits)]


def pyramid_combination(
    values: list, weight_floor: list, weight_ceil: list
) -> tf.Tensor:
    r"""
    Calculates linear interpolation (a weighted sum) using values of
    hypercube corners in dimension n.

    For example, when num_dimension = len(loc_shape) = num_bits = 3,
    values correspond to values at corners of the following coordinates

    .. code-block:: python

        [[0, 0, 0], # even
         [0, 0, 1], # odd
         [0, 1, 0], # even
         [0, 1, 1], # odd
         [1, 0, 0], # even
         [1, 0, 1], # odd
         [1, 1, 0], # even
         [1, 1, 1]] # odd

    values[::2] correspond to the corners with last coordinate == 0

    .. code-block:: python

        [[0, 0, 0],
         [0, 1, 0],
         [1, 0, 0],
         [1, 1, 0]]

    values[1::2] correspond to the corners with last coordinate == 1

    .. code-block:: python

        [[0, 0, 1],
         [0, 1, 1],
         [1, 0, 1],
         [1, 1, 1]]

    The weights correspond to the floor corners.
    For example, when num_dimension = len(loc_shape) = num_bits = 3,
    weight_floor = [f1, f2, f3] (ignoring the batch dimension) and
    weight_ceil = [c1, c2, c3] (ignoring the batch dimension).

    So for a corner with coordinates (x, y, z), where x, y, z are 0 or 1,

    - weight for x = f1 if x = 0 else c1
    - weight for y = f2 if y = 0 else c2
    - weight for z = f3 if z = 0 else c3

    so the weight for (x, y, z) is

    .. code-block:: text

        W_xyz = ((1-x) * f1 + x * c1)
              * ((1-y) * f2 + y * c2)
              * ((1-z) * f3 + z * c3)

    Let

    .. code-block:: text

        W_xy = ((1-x) * f1 + x * c1)
             * ((1-y) * f2 + y * c2)

    Then

    - W_xy0 = W_xy * f3
    - W_xy1 = W_xy * c3

    Similar to W_xyz, denote by V_xyz the value at (x, y, z);
    the final sum V equals

    .. code-block:: text

          sum over x,y,z (V_xyz * W_xyz)
        = sum over x,y (V_xy0 * W_xy0 + V_xy1 * W_xy1)
        = sum over x,y (V_xy0 * W_xy * f3 + V_xy1 * W_xy * c3)
        = sum over x,y (V_xy0 * W_xy) * f3 + sum over x,y (V_xy1 * W_xy) * c3

    That's why we call this a pyramid combination.
    It calculates the linear interpolation gradually, starting from
    the last dimension.
    The key is that the weight of each corner is the product of the weights
    along each dimension.

    :param values: a list having values on the corners,
                   it has 2**n tensors of shape
                   (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch),
                   the order is consistent with get_n_bits_combinations,
                   loc_shape is independent from n, aka num_dim
    :param weight_floor: a list having weights of floor points,
                    it has n tensors of shape
                    (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :param weight_ceil: a list having weights of ceil points,
                    it has n tensors of shape
                    (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :return: one tensor of the same shape as an element in values,
             (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :raises ValueError: if the tensor shapes or the list lengths are inconsistent
    """
    v_shape = values[0].shape
    wf_shape = weight_floor[0].shape
    wc_shape = weight_ceil[0].shape
    if len(v_shape) != len(wf_shape) or len(v_shape) != len(wc_shape):
        raise ValueError(
            "In pyramid_combination, elements of "
            "values, weight_floor, and weight_ceil should have same dimension. "
            f"value shape = {v_shape}, "
            f"weight_floor = {wf_shape}, "
            f"weight_ceil = {wc_shape}."
        )
    if 2 ** len(weight_floor) != len(values):
        raise ValueError(
            "In pyramid_combination, "
            "num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim, "
            f"but len(weight_floor) = {len(weight_floor)}, "
            f"len(values) = {len(values)}"
        )

    if len(weight_floor) == 1:  # one dimension
        return values[0] * weight_floor[0] + values[1] * weight_ceil[0]
    # multi dimension
    values_floor = pyramid_combination(
        values=values[::2],
        weight_floor=weight_floor[:-1],
        weight_ceil=weight_ceil[:-1],
    )
    values_floor = values_floor * weight_floor[-1]
    values_ceil = pyramid_combination(
        values=values[1::2],
        weight_floor=weight_floor[:-1],
        weight_ceil=weight_ceil[:-1],
    )
    values_ceil = values_ceil * weight_ceil[-1]
    return values_floor + values_ceil


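# A worked 1-D instance of the pyramid combination above (a sketch, assuming
# eager TensorFlow): interpolating between corner values 10 (floor) and
# 20 (ceil) with floor weight 0.75 and ceil weight 0.25 gives
# 0.75 * 10 + 0.25 * 20 = 12.5.
#
#     >>> out = pyramid_combination(
#     ...     values=[tf.constant([10.0]), tf.constant([20.0])],
#     ...     weight_floor=[tf.constant([0.75])],
#     ...     weight_ceil=[tf.constant([0.25])],
#     ... )
#     >>> out.numpy()
#     array([12.5], dtype=float32)

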
def resample(
    vol: tf.Tensor,
    loc: tf.Tensor,
    interpolation: str = "linear",
    zero_boundary: bool = True,
) -> tf.Tensor:
    r"""
    Sample the volume at given locations.

    Input has

    - volume, vol, of shape = (batch, v_dim 1, ..., v_dim n),
      or (batch, v_dim 1, ..., v_dim n, ch),
      where n is the dimension of volume,
      ch is the extra dimension as features.

      Denote vol_shape = (v_dim 1, ..., v_dim n)

    - location, loc, of shape = (batch, l_dim 1, ..., l_dim m, n),
      where m is the dimension of output.

      Denote loc_shape = (l_dim 1, ..., l_dim m)

    Reference:

    - neuron's interpn
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      Differences

      1. they don't have batch size
      2. they support more dimensions in vol

      TODO try not using stack as neuron claims it's slower

    :param vol: shape = (batch, \*vol_shape) or (batch, \*vol_shape, ch)
      with the last channel for features
    :param loc: shape = (batch, \*loc_shape, n)
      such that loc[b, l1, ..., lm, :] = [v1, ..., vn] is of shape (n,),
      which represents a point in vol, with coordinates (v1, ..., vn)
    :param interpolation: linear only, TODO support nearest
    :param zero_boundary: if true, values on or outside boundary will be zeros
    :return: shape = (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if interpolation is not linear,
      or if vol's shape is inconsistent with loc's
    """
    if interpolation != "linear":
        raise ValueError("resample supports only linear interpolation")

    # init
    batch_size = vol.shape[0]
    loc_shape = loc.shape[1:-1]
    dim_vol = loc.shape[-1]  # dimension of vol, n
    if dim_vol == len(vol.shape) - 1:
        # vol.shape = (batch, *vol_shape)
        has_ch = False
    elif dim_vol == len(vol.shape) - 2:
        # vol.shape = (batch, *vol_shape, ch)
        has_ch = True
    else:
        raise ValueError(
            "vol shape inconsistent with loc, "
            "vol.shape = {}, loc.shape = {}".format(vol.shape, loc.shape)
        )
    vol_shape = vol.shape[1 : dim_vol + 1]

    # get floor/ceil for loc and stack, then clip together
    # loc, loc_floor, loc_ceil have shape (batch, *loc_shape, n)
    loc_ceil = tf.math.ceil(loc)
    loc_floor = loc_ceil - 1
    # (batch, *loc_shape, n, 3)
    clipped = tf.stack([loc, loc_floor, loc_ceil], axis=-1)
    clip_value_max = tf.cast(vol_shape, dtype=clipped.dtype) - 1  # (n,)
    clipped_shape = [1] * (len(loc_shape) + 1) + [dim_vol, 1]
    clip_value_max = tf.reshape(clip_value_max, shape=clipped_shape)
    clipped = tf.clip_by_value(clipped, clip_value_min=0, clip_value_max=clip_value_max)

    # loc_floor_ceil has n sublists,
    # each corresponds to the floor and ceil coordinates for the d-th dimension,
    # each tensor is of shape (batch, *loc_shape), dtype int32

    # weight_floor has n tensors,
    # each tensor is the weight for the corner of floor coordinates,
    # each tensor's shape is (batch, *loc_shape) if volume has no feature channel,
    #                        (batch, *loc_shape, 1) if volume has feature channel
    loc_floor_ceil, weight_floor, weight_ceil = [], [], []
    # using a for loop is faster than using a list comprehension
    for dim in range(dim_vol):
        # shape = (batch, *loc_shape)
        c_clipped = clipped[..., dim, 0]
        c_floor = clipped[..., dim, 1]
        c_ceil = clipped[..., dim, 2]
        w_floor = c_ceil - c_clipped  # shape = (batch, *loc_shape)
        w_ceil = c_clipped - c_floor if zero_boundary else 1 - w_floor
        if has_ch:
            w_floor = tf.expand_dims(w_floor, -1)  # shape = (batch, *loc_shape, 1)
            w_ceil = tf.expand_dims(w_ceil, -1)  # shape = (batch, *loc_shape, 1)

        loc_floor_ceil.append([tf.cast(c_floor, tf.int32), tf.cast(c_ceil, tf.int32)])
        weight_floor.append(w_floor)
        weight_ceil.append(w_ceil)

    # 2**n corners, each is a list of n binary values
    corner_indices = get_n_bits_combinations(num_bits=len(vol_shape))

    # batch_coords[b, l1, ..., lm] = b
    # range(batch_size) on axis 0 and repeated on other axes
    # adding batch coords manually is faster than using batch_dims in tf.gather_nd
    batch_coords = tf.tile(
        tf.reshape(tf.range(batch_size), [batch_size] + [1] * len(loc_shape)),
        [1] + loc_shape,
    )  # shape = (batch, *loc_shape)

    # get vol values on n-dim hypercube corners
    # corner_values has 2 ** n elements,
    # each of shape (batch, *loc_shape) or (batch, *loc_shape, ch)
    corner_values = [
        tf.gather_nd(
            vol,  # shape = (batch, *vol_shape) or (batch, *vol_shape, ch)
            tf.stack(
                [batch_coords]
                + [loc_floor_ceil[axis][fc_idx] for axis, fc_idx in enumerate(c)],
                axis=-1,
            ),  # shape = (batch, *loc_shape, n+1) after stack
        )
        for c in corner_indices  # c is a list of len n
    ]  # each tensor has shape (batch, *loc_shape) or (batch, *loc_shape, ch)

    # resample
    sampled = pyramid_combination(
        values=corner_values, weight_floor=weight_floor, weight_ceil=weight_ceil
    )
    return sampled


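# Illustrative usage of resample (a sketch, assuming eager TensorFlow):
# sampling a 2x2 volume at its center (0.5, 0.5) averages all four corners,
# (0 + 1 + 2 + 3) / 4 = 1.5.
#
#     >>> vol = tf.constant([[[0.0, 1.0], [2.0, 3.0]]])  # shape = (1, 2, 2)
#     >>> loc = tf.constant([[[0.5, 0.5]]])  # shape = (1, 1, n=2)
#     >>> resample(vol=vol, loc=loc).numpy()
#     array([[1.5]], dtype=float32)

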
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: Optional[int] = None
) -> tf.Tensor:
    """
    Function that generates random 3D affine transformation parameters
    for a batch of data.

    For 3D coordinates, an affine transformation is

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    where each * represents a degree of freedom,
    so there are in total 12 degrees of freedom.
    The equation can be denoted as

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 4)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 4)

    The equation can be simplified to

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    so that

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 3)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 3)

    Given original and transformed coordinates,
    we can calculate the transformation matrix using

        x = np.linalg.lstsq(a, b)

    such that

        a x = b

    In our case,

    - a = old
    - b = new
    - x = T

    To generate a random transformation,
    we choose to add random perturbation to corner coordinates as follows:
    for a corner of coordinates (x, y, z), the noise is

        -(x, y, z) .* (r1, r2, r3)

    where ri is a random number sampled from U(0, scale).
    So

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    Thus, we can directly sample between 1-scale and 1 instead.

    We choose to calculate the transformation based on
    four corners of a cube centered at (0, 0, 0).
    The cube is shown below, where

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    :param batch_size: int
    :param scale: a float number between 0 and 1
    :param seed: control the randomness
    :return: shape = (batch, 4, 3)
    """
    assert 0 <= scale <= 1
    np.random.seed(seed)
    noise = np.random.uniform(1 - scale, 1, [batch_size, 4, 3])  # shape = (batch, 4, 3)

    # old represents four corners of a cube,
    # corresponding to the corners C, G, D, A as shown above
    old = np.tile(
        [[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, -1, 1], [1, -1, -1, 1]]],
        [batch_size, 1, 1],
    )  # shape = (batch, 4, 4)
    new = old[:, :, :3] * noise  # shape = (batch, 4, 3)

    theta = np.array(
        [np.linalg.lstsq(old[k], new[k], rcond=-1)[0] for k in range(batch_size)]
    )  # shape = (batch, 4, 3)

    return tf.cast(theta, dtype=tf.float32)


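# Illustrative usage (a sketch): with scale=0 the sampled noise is exactly 1
# and the recovered transform is the identity mapping; with scale>0 the
# output is a random transform of shape (batch, 4, 3).
#
#     >>> theta = gen_rand_affine_transform(batch_size=2, scale=0.1, seed=0)
#     >>> theta.shape
#     TensorShape([2, 4, 3])

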
def gen_rand_ddf(
    batch_size: int,
    image_size: tuple,
    field_strength: Union[tuple, list],
    low_res_size: Union[tuple, list],
    seed: Optional[int] = None,
) -> tf.Tensor:
    """
    Function that generates a random 3D DDF for a batch of data.

    :param batch_size: number of samples in the batch
    :param image_size: size of the DDF to return, (dim1, dim2, dim3)
    :param field_strength: maximum field strength,
        the strength is sampled from U[0, field_strength]
    :param low_res_size: size of the low-resolution deformation field that will be
        upsampled to the original size in order to get smooth and more realistic fields
    :param seed: control the randomness
    :return: DDF, shape = (batch, dim1, dim2, dim3, 3)
    """
    np.random.seed(seed)
    low_res_strength = np.random.uniform(0, field_strength, (batch_size, 1, 1, 1, 3))
    low_res_field = low_res_strength * np.random.randn(
        batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3
    )
    high_res_field = resize3d(low_res_field, image_size)
    return high_res_field


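# Illustrative usage (a sketch): a low-resolution 4x4x4 field upsampled to an
# 8x8x8 DDF, one 3-vector displacement per voxel.
#
#     >>> ddf = gen_rand_ddf(
#     ...     batch_size=1,
#     ...     image_size=(8, 8, 8),
#     ...     field_strength=(2, 2, 2),
#     ...     low_res_size=(4, 4, 4),
#     ...     seed=0,
#     ... )
#     >>> ddf.shape
#     TensorShape([1, 8, 8, 8, 3])

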
def warp_grid(grid: tf.Tensor, theta: tf.Tensor) -> tf.Tensor:
    """
    Perform transformation on the grid.

    - grid_padded[i,j,k,:] = [i j k 1]
    - grid_warped[b,i,j,k,p] = sum_over_q (grid_padded[i,j,k,q] * theta[b,q,p])

    :param grid: shape = (dim1, dim2, dim3, 3), grid[i,j,k,:] = [i j k]
    :param theta: parameters of transformation, shape = (batch, 4, 3)
    :return: shape = (batch, dim1, dim2, dim3, 3)
    """
    # grid_padded[i,j,k,:] = [i j k 1], shape = (dim1, dim2, dim3, 4)
    grid_padded = tf.concat([grid, tf.ones_like(grid[..., :1])], axis=3)

    # grid_warped[b,i,j,k,p] = sum_over_q (grid_padded[i,j,k,q] * theta[b,q,p])
    # shape = (batch, dim1, dim2, dim3, 3)
    grid_warped = tf.einsum("ijkq,bqp->bijkp", grid_padded, theta)
    return grid_warped


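# Illustrative usage (a sketch): an identity theta, i.e. the 3x3 identity on
# top of a zero translation row, maps the reference grid onto itself.
#
#     >>> grid = get_reference_grid((2, 2, 2))
#     >>> theta = tf.constant([[[1.0, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]]])
#     >>> warped = warp_grid(grid, theta)  # shape = (1, 2, 2, 2, 3)
#     >>> bool(tf.reduce_all(warped[0] == grid))
#     True

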
def resize3d(
    image: tf.Tensor,
    size: Union[tuple, list],
    method: str = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
    """
    Tensorflow does not have a 3D resize, therefore the resize is performed
    in two stages:

    - resize dim2 and dim3
    - resize dim1 and dim2

    :param image: tensor of shape = (batch, dim1, dim2, dim3, channels)
                                 or (batch, dim1, dim2, dim3)
                                 or (dim1, dim2, dim3)
    :param size: tuple or list, (out_dim1, out_dim2, out_dim3)
    :param method: str, one of tf.image.ResizeMethod
    :return: tensor of shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                            or (batch, out_dim1, out_dim2, out_dim3)
                            or (out_dim1, out_dim2, out_dim3)
    :raises ValueError: if image's dimension is not 3, 4 or 5,
        or if size is not of length 3
    """
    # sanity check
    image_dim = len(image.shape)
    if image_dim not in [3, 4, 5]:
        raise ValueError(
            "resize3d takes input image of dimension 3 or 4 or 5, "
            "corresponding to (dim1, dim2, dim3) "
            "or (batch, dim1, dim2, dim3) "
            "or (batch, dim1, dim2, dim3, channels), "
            "got image shape {}".format(image.shape)
        )
    if len(size) != 3:
        raise ValueError("resize3d takes size of type tuple/list and of length 3")

    # init
    if image_dim == 5:
        has_channel = True
        has_batch = True
        input_image_shape = image.shape[1:4]
    elif image_dim == 4:
        has_channel = False
        has_batch = True
        input_image_shape = image.shape[1:4]
    else:
        has_channel = False
        has_batch = False
        input_image_shape = image.shape[0:3]

    # no need to resize
    if input_image_shape == tuple(size):
        return image

    # expand to five dimensions
    if not has_batch:
        image = tf.expand_dims(image, axis=0)
    if not has_channel:
        image = tf.expand_dims(image, axis=-1)
    assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
    image_shape = tf.shape(image)

    # merge axis 0 and 1
    output = tf.reshape(
        image, (-1, image_shape[2], image_shape[3], image_shape[4])
    )  # (batch * dim1, dim2, dim3, channels)

    # resize dim2 and dim3
    output = tf.image.resize(
        images=output, size=size[1:], method=method
    )  # (batch * dim1, out_dim2, out_dim3, channels)

    # split axis 0 and merge axis 3 and 4
    output = tf.reshape(
        output, shape=(-1, image_shape[1], size[1], size[2] * image_shape[4])
    )  # (batch, dim1, out_dim2, out_dim3 * channels)

    # resize dim1 and dim2
    output = tf.image.resize(
        images=output, size=size[:2], method=method
    )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

    # reshape
    output = tf.reshape(
        output, shape=[-1, *size, image_shape[4]]
    )  # (batch, out_dim1, out_dim2, out_dim3, channels)

    # squeeze to original dimension
    if not has_batch:
        output = tf.squeeze(output, axis=0)
    if not has_channel:
        output = tf.squeeze(output, axis=-1)
    return output


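# Illustrative usage (a sketch): upsampling a 4x4x4 single-channel volume to
# 8x8x8; batch and channel dimensions are preserved.
#
#     >>> image = tf.random.uniform((1, 4, 4, 4, 1))
#     >>> resize3d(image=image, size=(8, 8, 8)).shape
#     TensorShape([1, 8, 8, 8, 1])

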
def gaussian_filter_3d(kernel_sigma: Union[list, tuple, int]) -> tf.Tensor:
    """
    Define a gaussian filter in 3d for smoothing.

    The filter size along each direction is the nearest odd integer
    that is at least 3 * kernel_sigma.

    :param kernel_sigma: the deviation in each direction (list or tuple)
        or an isotropic deviation (int)
    :return: kernel: tf.Tensor of shape (\\*kernel_size, 3, 3),
        the same gaussian kernel placed on the diagonal of the two
        channel dimensions
    """
    if isinstance(kernel_sigma, int):
        kernel_sigma = (kernel_sigma, kernel_sigma, kernel_sigma)

    # nearest odd integer >= 3 * sigma, per direction
    kernel_size = [
        int(np.ceil(ks * 3) + np.mod(np.ceil(ks * 3) + 1, 2)) for ks in kernel_sigma
    ]

    # create an (i, j, k) coordinate grid of shape (3, *kernel_size)
    coord = [np.arange(ks) for ks in kernel_size]
    xx, yy, zz = np.meshgrid(coord[0], coord[1], coord[2], indexing="ij")
    xyz_grid = np.concatenate(
        (xx[np.newaxis], yy[np.newaxis], zz[np.newaxis]), axis=0
    )  # shape = (3, *kernel_size)

    mean = np.asarray([(ks - 1) / 2.0 for ks in kernel_size])
    mean = mean.reshape(-1, 1, 1, 1)
    variance = np.asarray([ks ** 2.0 for ks in kernel_sigma])
    variance = variance.reshape(-1, 1, 1, 1)

    # calculate the 3-dimensional gaussian kernel, which is the product of
    # three gaussian distributions, one per spatial direction;
    # the constant factor 1 / (sqrt(2 * pi)**3 * sigma1 * sigma2 * sigma3)
    # is irrelevant as the kernel is normalized below
    norm_kernel = 1.0 / (np.sqrt(2 * np.pi) ** 3 * np.prod(kernel_sigma))
    kernel = norm_kernel * np.exp(
        -np.sum((xyz_grid - mean) ** 2.0 / (2 * variance), axis=0)
    )

    # make sure the sum of values in the gaussian kernel equals 1
    kernel = kernel / np.sum(kernel)

    # reshape
    kernel = kernel.reshape(kernel_size[0], kernel_size[1], kernel_size[2])

    # total kernel: the same kernel on each diagonal channel pair,
    # so each of the three channels is smoothed independently
    total_kernel = np.zeros(tuple(kernel_size) + (3, 3))
    total_kernel[..., 0, 0] = kernel
    total_kernel[..., 1, 1] = kernel
    total_kernel[..., 2, 2] = kernel

    return tf.convert_to_tensor(total_kernel, dtype=tf.float32)
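

# Illustrative usage (a sketch): with an isotropic sigma of 1, each spatial
# side is the nearest odd integer >= 3, and the kernel can smooth a DDF's
# three channels independently via tf.nn.conv3d.
#
#     >>> kernel = gaussian_filter_3d(kernel_sigma=1)
#     >>> kernel.shape
#     TensorShape([3, 3, 3, 3, 3])
#     >>> ddf = tf.random.uniform((1, 8, 8, 8, 3))
#     >>> tf.nn.conv3d(ddf, kernel, strides=[1] * 5, padding="SAME").shape
#     TensorShape([1, 8, 8, 8, 3])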