Passed
Pull Request — main (#656)
by Yunguan
03:05
created

deepreg.model.layer_util.warp_image_ddf()   C

Complexity

Conditions 9

Size

Total Lines 44
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 44
rs 6.6666
c 0
b 0
f 0
cc 9
nop 3
1
"""
2
Module containing utilities for layer inputs
3
"""
4
import itertools
5
6
import numpy as np
7
import tensorflow as tf
8
9
10
def get_reference_grid(grid_size: (tuple, list)) -> tf.Tensor:
    """
    Build a dense 3D grid holding the index of every voxel.

    The grid is assembled with ``tf.meshgrid`` using ``"ij"`` indexing:
    in the 3-D case with inputs of length M, N and P, each mesh has
    shape (M, N, P) (``"xy"`` indexing would give (N, M, P) instead).

    Reference:

    - volshape_to_meshgrid of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      neuron modifies meshgrid to make it faster, however local
      benchmark suggests tf.meshgrid is better

    :param grid_size: list or tuple of size 3, [dim1, dim2, dim3]
    :return: float32 tensor of shape = (dim1, dim2, dim3, 3),
             grid[i, j, k, :] = [i j k]
    """
    # one index range per spatial axis
    axis_ranges = [tf.range(grid_size[axis]) for axis in range(3)]
    # three meshes, each of shape (dim1, dim2, dim3);
    # the d-th mesh stores the index along axis d, i.e.
    # meshes[0][i, j, k] = i, meshes[1][i, j, k] = j, meshes[2][i, j, k] = k
    meshes = tf.meshgrid(*axis_ranges, indexing="ij")
    # stack on a new trailing axis, shape = (dim1, dim2, dim3, 3)
    return tf.cast(tf.stack(meshes, axis=3), dtype=tf.float32)
50
51
52
def get_n_bits_combinations(num_bits: int) -> list:
    """
    Enumerate every combination of ``num_bits`` binary digits.

    Each bit is either 0 or 1, so there are 2**num_bits combinations.
    They are returned in ascending order of the integer they encode,
    with the first element of each combination being the most
    significant bit.

    :param num_bits: int, number of bits, must be at least 1
    :return: a list of length 2**num_bits,
      return[i] is the binary representation of the decimal integer i,
      given as a list of num_bits values in {0, 1}.

    :Example:
        >>> from deepreg.model.layer_util import get_n_bits_combinations
        >>> get_n_bits_combinations(3)
        [[0, 0, 0], # 0
         [0, 0, 1], # 1
         [0, 1, 0], # 2
         [0, 1, 1], # 3
         [1, 0, 0], # 4
         [1, 0, 1], # 5
         [1, 1, 0], # 6
         [1, 1, 1]] # 7
    """
    assert num_bits >= 1
    # itertools.product yields tuples in lexicographic order,
    # convert each of them to a list
    return list(map(list, itertools.product((0, 1), repeat=num_bits)))
76
77
78
def pyramid_combination(
    values: list, weight_floor: list, weight_ceil: list
) -> tf.Tensor:
    r"""
    Linearly interpolate (as a weighted sum) the values given on the
    corners of an n-dimensional hypercube.

    The corners are ordered as in get_n_bits_combinations.
    For example, when n = 3, values correspond to the corners

    .. code-block:: python

        [[0, 0, 0], # even
         [0, 0, 1], # odd
         [0, 1, 0], # even
         [0, 1, 1], # odd
         [1, 0, 0], # even
         [1, 0, 1], # odd
         [1, 1, 0], # even
         [1, 1, 1]] # odd

    so values[::2] are the corners with last coordinate == 0 and
    values[1::2] are the corners with last coordinate == 1.

    With weight_floor = [f1, f2, f3] and weight_ceil = [c1, c2, c3]
    (ignoring the batch dimension), the weight of a corner (x, y, z),
    where x, y, z are 0 or 1, is the product of the per-dimension weights

    .. code-block:: text

        W_xyz = ((1-x) * f1 + x * c1)
              * ((1-y) * f2 + y * c2)
              * ((1-z) * f3 + z * c3)

    Denote W_xy the product of the first two factors, so that
    W_xy0 = W_xy * f3 and W_xy1 = W_xy * c3.
    With V_xyz the value at corner (x, y, z), the interpolated value is

    .. code-block:: text

          sum over x,y,z (V_xyz * W_xyz)
        = sum over x,y (V_xy0 * W_xy) * f3 + sum over x,y (V_xy1 * W_xy) * c3

    Both sums are interpolations in one dimension less, therefore the
    function recurses on the even / odd halves of values with the last
    weight removed, and then combines the two partial results using the
    weights of the last dimension.

    :param values: a list having values on the corner,
                   it has 2**n tensors of shape
                   (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
                   the order is consistent with get_n_bits_combinations
                   loc_shape is independent from n, aka num_dim
    :param weight_floor: a list having weights of floor points,
                    it has n tensors of shape
                    (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :param weight_ceil: a list having weights of ceil points,
                    it has n tensors of shape
                    (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, 1)
    :return: one tensor of the same shape as an element in values
             (\*loc_shape) or (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if the first elements of values, weight_floor and
                    weight_ceil do not have the same number of dimensions,
                    or if len(values) is not 2 ** len(weight_floor)
    """
    # only the first element of each list is inspected
    v_shape = values[0].shape
    wf_shape = weight_floor[0].shape
    wc_shape = weight_ceil[0].shape
    ranks_match = len(v_shape) == len(wf_shape) and len(v_shape) == len(wc_shape)
    if not ranks_match:
        raise ValueError(
            "In pyramid_combination, elements of "
            "values, weight_floor, and weight_ceil should have same dimension. "
            f"value shape = {v_shape}, "
            f"weight_floor = {wf_shape}, "
            f"weight_ceil = {wc_shape}."
        )

    num_dim = len(weight_floor)
    if len(values) != 2 ** num_dim:
        raise ValueError(
            "In pyramid_combination, "
            "num_dim = len(weight_floor), "
            "len(values) must be 2 ** num_dim, "
            f"But len(weight_floor) = {len(weight_floor)}, "
            f"len(values) = {len(values)}"
        )

    # one dimension, plain linear interpolation between two corners
    if num_dim == 1:
        return values[0] * weight_floor[0] + values[1] * weight_ceil[0]

    # multi dimension, both halves share the weights of the leading dimensions
    leading_floor = weight_floor[:-1]
    leading_ceil = weight_ceil[:-1]

    # corners with last coordinate == 0, weighted by the last floor weight
    floor_part = pyramid_combination(
        values=values[::2],
        weight_floor=leading_floor,
        weight_ceil=leading_ceil,
    )
    floor_part = floor_part * weight_floor[-1]

    # corners with last coordinate == 1, weighted by the last ceil weight
    ceil_part = pyramid_combination(
        values=values[1::2],
        weight_floor=leading_floor,
        weight_ceil=leading_ceil,
    )
    ceil_part = ceil_part * weight_ceil[-1]

    return floor_part + ceil_part
214
215
216
def resample(
    vol: tf.Tensor,
    loc: tf.Tensor,
    interpolation: str = "linear",
    zero_boundary: bool = True,
) -> tf.Tensor:
    r"""
    Sample the volume at given locations using linear interpolation.

    Input has

    - volume, vol, of shape = (batch, v_dim 1, ..., v_dim n),
      or (batch, v_dim 1, ..., v_dim n, ch),
      where n is the dimension of volume,
      ch is the extra dimension as features.

      Denote vol_shape = (v_dim 1, ..., v_dim n)

    - location, loc, of shape = (batch, l_dim 1, ..., l_dim m, n),
      where m is the dimension of output.

      Denote loc_shape = (l_dim 1, ..., l_dim m)

    For each location, the values of vol on the 2**n surrounding
    integer corners are gathered and combined by pyramid_combination.
    Coordinates are clipped to [0, vol_shape - 1] before gathering.

    Reference:

    - neuron's interpn
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      Difference

      1. they dont have batch size
      2. they support more dimensions in vol

      TODO try not using stack as neuron claims it's slower

    :param vol: shape = (batch, \*vol_shape) or (batch, \*vol_shape, ch)
      with the last channel for features
    :param loc: shape = (batch, \*loc_shape, n)
      such that loc[b, l1, ..., lm, :] = [v1, ..., vn] is of shape (n,),
      which represents a point in vol, with coordinates (v1, ..., vn)
    :param interpolation: linear only, TODO support nearest
    :param zero_boundary: if true, the ceil weight is the distance between
      the clipped location and the clipped floor, so a location whose floor
      and ceil are clipped to the same index gets zero weight on both
      corners; if false, the ceil weight is 1 - floor weight, so the
      weights of each dimension always sum to one
    :return: shape = (batch, \*loc_shape) or (batch, \*loc_shape, ch)
    :raises ValueError: if interpolation is not "linear", or if the last
      dimension of loc matches neither the number of dimensions of vol
      without batch, nor without batch and channel
    """
    if interpolation != "linear":
        raise ValueError("resample supports only linear interpolation")

    # static shapes
    batch_size = vol.shape[0]
    loc_shape = loc.shape[1:-1]
    dim_vol = loc.shape[-1]  # dimension of vol, n

    # vol.shape is (batch, *vol_shape) or (batch, *vol_shape, ch)
    vol_rank = len(vol.shape)
    if dim_vol != vol_rank - 1 and dim_vol != vol_rank - 2:
        raise ValueError(
            "vol shape inconsistent with loc "
            "vol.shape = {}, loc.shape = {}".format(vol.shape, loc.shape)
        )
    has_ch = dim_vol == vol_rank - 2
    vol_shape = vol.shape[1 : dim_vol + 1]

    # stack loc with its floor / ceil so that all three are clipped at once
    # loc, loc_floor, loc_ceil have shape (batch, *loc_shape, n)
    # the floor is defined as ceil - 1, so it differs from the ceil
    # even when loc is an integer
    loc_ceil = tf.math.ceil(loc)
    loc_floor = loc_ceil - 1
    # (batch, *loc_shape, n, 3)
    clipped = tf.stack([loc, loc_floor, loc_ceil], axis=-1)
    # largest valid index per dimension, shape (n,),
    # reshaped to broadcast against clipped
    upper_bound = tf.cast(vol_shape, dtype=clipped.dtype) - 1
    broadcast_shape = [1] * (len(loc_shape) + 1) + [dim_vol, 1]
    upper_bound = tf.reshape(upper_bound, shape=broadcast_shape)
    clipped = tf.clip_by_value(clipped, clip_value_min=0, clip_value_max=upper_bound)

    # loc_floor_ceil has n sublists
    # each one holds the floor and ceil coordinates for d-th dimension
    # each tensor is of shape (batch, *loc_shape), dtype int32

    # weight_floor / weight_ceil have n tensors
    # each tensor is the weight for the corner of floor / ceil coordinates
    # each tensor's shape is (batch, *loc_shape) if volume has no feature channel
    #                        (batch, *loc_shape, 1) if volume has feature channel
    loc_floor_ceil = []
    weight_floor = []
    weight_ceil = []
    # using for loop is faster than using list comprehension
    for dim in range(dim_vol):
        # each of shape = (batch, *loc_shape)
        coord = clipped[..., dim, 0]
        coord_floor = clipped[..., dim, 1]
        coord_ceil = clipped[..., dim, 2]

        w_floor = coord_ceil - coord  # shape = (batch, *loc_shape)
        if zero_boundary:
            w_ceil = coord - coord_floor
        else:
            w_ceil = 1 - w_floor

        if has_ch:
            # add a trailing axis to broadcast over channels
            w_floor = tf.expand_dims(w_floor, -1)  # shape = (batch, *loc_shape, 1)
            w_ceil = tf.expand_dims(w_ceil, -1)  # shape = (batch, *loc_shape, 1)

        floor_index = tf.cast(coord_floor, tf.int32)
        ceil_index = tf.cast(coord_ceil, tf.int32)
        loc_floor_ceil.append([floor_index, ceil_index])
        weight_floor.append(w_floor)
        weight_ceil.append(w_ceil)

    # 2**n corners, each is a list of n binary values
    corner_indices = get_n_bits_combinations(num_bits=len(vol_shape))

    # batch_coords[b, l1, ..., lm] = b
    # range(batch_size) on axis 0 and repeated on other axes
    # add batch coords manually is faster than using batch_dims in tf.gather_nd
    batch_column = tf.reshape(
        tf.range(batch_size), [batch_size] + [1] * len(loc_shape)
    )
    batch_coords = tf.tile(batch_column, [1] + loc_shape)  # (batch, *loc_shape)

    # get vol values on n-dim hypercube corners
    # corner_values has 2 ** n elements
    # each of shape (batch, *loc_shape) or (batch, *loc_shape, ch)
    corner_values = []
    for corner in corner_indices:  # corner is list of len n
        # for each axis pick the floor (0) or the ceil (1) index
        axis_coords = [
            loc_floor_ceil[axis][fc_idx] for axis, fc_idx in enumerate(corner)
        ]
        # shape = (batch, *loc_shape, n+1) after stack
        gather_indices = tf.stack([batch_coords] + axis_coords, axis=-1)
        # vol has shape = (batch, *vol_shape) or (batch, *vol_shape, ch)
        corner_values.append(tf.gather_nd(vol, gather_indices))

    # weighted sum of the corner values
    return pyramid_combination(
        values=corner_values, weight_floor=weight_floor, weight_ceil=weight_ceil
    )
348
349
350
def warp_grid(grid: tf.Tensor, theta: tf.Tensor) -> tf.Tensor:
    """
    Transform the grid with a batch of transformation parameters.

    The grid is padded with ones on the last axis and multiplied
    with theta:

    - grid_padded[i,j,k,:] = [i j k 1]
    - grid_warped[b,i,j,k,p] = sum_over_q (grid_padded[i,j,k,q] * theta[b,q,p])

    :param grid: shape = (dim1, dim2, dim3, 3), grid[i,j,k,:] = [i j k]
    :param theta: parameters of transformation, shape = (batch, 4, 3)
    :return: shape = (batch, dim1, dim2, dim3, 3)
    """
    # a column of ones with the dtype of grid, shape = (dim1, dim2, dim3, 1)
    ones = tf.ones_like(grid[..., :1])
    # homogeneous[i,j,k,:] = [i j k 1], shape = (dim1, dim2, dim3, 4)
    homogeneous = tf.concat([grid, ones], axis=3)
    # contract the padded axis q against theta for every batch element b,
    # shape = (batch, dim1, dim2, dim3, 3)
    return tf.einsum("ijkq,bqp->bijkp", homogeneous, theta)
369
370
371
def gaussian_filter_3d(kernel_sigma: (list, tuple, int)) -> tf.Tensor:
    """
    Define a gaussian filter in 3d for smoothing.

    Along each axis the filter size is ceil(3 * sigma), increased by one
    if that number is even, so that the size is always odd and the
    gaussian is centred on the middle element.

    :param kernel_sigma: the deviation at each direction
        (list or tuple of 3 numbers),
        or an isotropic deviation (a single number, int or float)
    :return: kernel: tf.Tensor of dtype float32 and
        shape (k1, k2, k3, 3, 3), where (k1, k2, k3) is the filter size.
        kernel[..., c, c] is the gaussian kernel, normalised to sum to one,
        for c in 0, 1, 2, all other entries are zero.
    :raises ValueError: if kernel_sigma is a sequence whose length is not 3
    """
    # a scalar (python or numpy, int or float) means isotropic smoothing
    if np.ndim(kernel_sigma) == 0:
        kernel_sigma = (kernel_sigma,) * 3
    if len(kernel_sigma) != 3:
        raise ValueError(
            "kernel_sigma must be a scalar or have exactly 3 elements, "
            f"got {len(kernel_sigma)} elements."
        )

    # ceil(3 * sigma), rounded up to the next odd number
    kernel_size = []
    for sigma in kernel_sigma:
        extent = int(np.ceil(sigma * 3))
        kernel_size.append(extent + (extent + 1) % 2)

    # coordinates of every kernel element, shape = (3, k1, k2, k3)
    coords = [np.arange(size) for size in kernel_size]
    xyz_grid = np.stack(np.meshgrid(*coords, indexing="ij"), axis=0)

    # per-axis centre and variance, shape = (3, 1, 1, 1) for broadcasting
    mean = np.asarray([(size - 1) / 2.0 for size in kernel_size])
    mean = mean.reshape(-1, 1, 1, 1)
    variance = np.asarray([sigma ** 2.0 for sigma in kernel_sigma])
    variance = variance.reshape(-1, 1, 1, 1)

    # The 3-dimensional gaussian is the product of three 1-dimensional
    # gaussians, i.e. the exponential of the summed exponents.
    # The analytic normalisation constant is omitted on purpose:
    # the kernel is rescaled to sum to one right below, which cancels
    # any constant factor.
    kernel = np.exp(-np.sum((xyz_grid - mean) ** 2.0 / (2 * variance), axis=0))

    # Make sure sum of values in gaussian kernel equals 1.
    kernel = kernel / np.sum(kernel)

    # Put the kernel on the diagonal of the last two axes only,
    # presumably so that a 3d convolution smooths each of the three
    # channels independently - TODO confirm against the caller
    total_kernel = np.zeros(tuple(kernel_size) + (3, 3))
    for channel in range(3):
        total_kernel[..., channel, channel] = kernel

    return tf.convert_to_tensor(total_kernel, dtype=tf.float32)
426