1
|
|
|
"""Provide regularization functions and classes for ddf.""" |
2
|
|
|
from typing import Callable |
3
|
|
|
|
4
|
|
|
import tensorflow as tf |
5
|
|
|
|
6
|
|
|
from deepreg.registry import REGISTRY |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
def gradient_dx(fx: tf.Tensor) -> tf.Tensor:
    """
    Approximate the derivative along the x-axis (axis 1) by central difference.

    For each interior voxel: dx[i] = (f[i+1] - f[i-1]) / 2.
    One voxel is cropped from both ends of every spatial axis, so each spatial
    dimension of the result is two smaller than the input's.

    :param fx: shape = (batch, m_dim1, m_dim2, m_dim3)
    :return: shape = (batch, m_dim1-2, m_dim2-2, m_dim3-2)
    """
    inner = slice(1, -1)
    forward = fx[:, 2:, inner, inner]
    backward = fx[:, :-2, inner, inner]
    return (forward - backward) / 2
20
|
|
|
|
21
|
|
|
|
22
|
|
|
def gradient_dy(fy: tf.Tensor) -> tf.Tensor:
    """
    Approximate the derivative along the y-axis (axis 2) by central difference.

    For each interior voxel: dy[i] = (f[i+1] - f[i-1]) / 2.
    One voxel is cropped from both ends of every spatial axis, so each spatial
    dimension of the result is two smaller than the input's.

    :param fy: shape = (batch, m_dim1, m_dim2, m_dim3)
    :return: shape = (batch, m_dim1-2, m_dim2-2, m_dim3-2)
    """
    inner = slice(1, -1)
    forward = fy[:, inner, 2:, inner]
    backward = fy[:, inner, :-2, inner]
    return (forward - backward) / 2
33
|
|
|
|
34
|
|
|
|
35
|
|
|
def gradient_dz(fz: tf.Tensor) -> tf.Tensor:
    """
    Approximate the derivative along the z-axis (axis 3) by central difference.

    For each interior voxel: dz[i] = (f[i+1] - f[i-1]) / 2.
    One voxel is cropped from both ends of every spatial axis, so each spatial
    dimension of the result is two smaller than the input's.

    :param fz: shape = (batch, m_dim1, m_dim2, m_dim3)
    :return: shape = (batch, m_dim1-2, m_dim2-2, m_dim3-2)
    """
    inner = slice(1, -1)
    forward = fz[:, inner, inner, 2:]
    backward = fz[:, inner, inner, :-2]
    return (forward - backward) / 2
46
|
|
|
|
47
|
|
|
|
48
|
|
|
def gradient_dxyz(fxyz: tf.Tensor, fn: Callable) -> tf.Tensor:
    """
    Calculate gradients on x,y,z-axis of a tensor using central finite difference.

    The last axis of ``fxyz`` holds three components (e.g. the three
    displacement components of a ddf). ``fn`` is applied to each component
    separately and the three results are stacked back along the last axis.

    :param fxyz: shape = (..., 3)
    :param fn: function applied to each component, e.g. gradient_dx
    :return: shape = (..., 3), where ``...`` is the shape returned by ``fn``
    """
    # Stack on the last axis so the output honours the (..., 3) contract for
    # any input rank; for 5D inputs this is identical to stacking on axis 4.
    return tf.stack([fn(fxyz[..., i]) for i in range(3)], axis=-1)
59
|
|
|
|
60
|
|
|
|
61
|
|
|
@REGISTRY.register_loss(name="gradient")
class GradientNorm(tf.keras.layers.Layer):
    """
    Calculate the L1/L2 norm of ddf using central finite difference.

    First-order spatial derivatives of every ddf component are estimated with
    central differences, penalised either by absolute value (L1) or by square
    (L2, no square root is taken), and averaged over all interior voxels and
    components of each sample.

    The input has to be a 5D tensor, including the batch axis:
    (batch, m_dim1, m_dim2, m_dim3, 3).
    """

    def __init__(self, l1: bool = False, name: str = "GradientNorm", **kwargs):
        """
        Init.

        :param l1: bool true if calculate L1 norm, otherwise L2 norm
        :param name: name of the loss
        :param kwargs: additional arguments; accepted but not forwarded to the
            Keras layer.
        """
        # NOTE(review): kwargs are deliberately not passed to Layer.__init__,
        # presumably so extra config keys supplied when the loss is built from a
        # config do not make Layer.__init__ raise — confirm before forwarding.
        super().__init__(name=name)
        self.l1 = l1

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Return the gradient regularization loss for each sample in the batch.

        :param inputs: ddf, shape = (batch, m_dim1, m_dim2, m_dim3, 3)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, )
        """
        assert len(inputs.shape) == 5
        ddf = inputs
        # first order gradient
        # (batch, m_dim1-2, m_dim2-2, m_dim3-2, 3)
        dfdx = gradient_dxyz(ddf, gradient_dx)
        dfdy = gradient_dxyz(ddf, gradient_dy)
        dfdz = gradient_dxyz(ddf, gradient_dz)
        if self.l1:
            norms = tf.abs(dfdx) + tf.abs(dfdy) + tf.abs(dfdz)
        else:
            # squared gradient magnitudes; the square root is intentionally
            # not taken
            norms = dfdx ** 2 + dfdy ** 2 + dfdz ** 2
        # average over the spatial and component axes, keeping the batch axis
        return tf.reduce_mean(norms, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["l1"] = self.l1
        return config
106
|
|
|
|
107
|
|
|
|
108
|
|
|
@REGISTRY.register_loss(name="bending")
class BendingEnergy(tf.keras.layers.Layer):
    """
    Calculate the bending energy of ddf using central finite difference.

    For each ddf component f the energy density is
        f_xx^2 + f_yy^2 + f_zz^2 + 2 * (f_xy^2 + f_xz^2 + f_yz^2),
    averaged over all interior voxels and components of each sample.

    The input has to be a 5D tensor, including the batch axis:
    (batch, m_dim1, m_dim2, m_dim3, 3).
    """

    def __init__(self, name: str = "BendingEnergy", **kwargs):
        """
        Init.

        :param name: name of the loss.
        :param kwargs: additional arguments; accepted but not forwarded to the
            Keras layer.
        """
        # NOTE(review): kwargs are deliberately not passed to Layer.__init__,
        # presumably so extra config keys supplied when the loss is built from a
        # config do not make Layer.__init__ raise — confirm before forwarding.
        super().__init__(name=name)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Return the bending energy for each sample in the batch.

        Second-order derivatives are obtained by applying the central
        difference twice, so each stencil reaches two voxels on either side and
        every spatial dimension shrinks by four.

        :param inputs: ddf, shape = (batch, m_dim1, m_dim2, m_dim3, 3)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, )
        """
        assert len(inputs.shape) == 5
        ddf = inputs
        # first order gradient
        # (batch, m_dim1-2, m_dim2-2, m_dim3-2, 3)
        dfdx = gradient_dxyz(ddf, gradient_dx)
        dfdy = gradient_dxyz(ddf, gradient_dy)
        dfdz = gradient_dxyz(ddf, gradient_dz)

        # second order gradient
        # (batch, m_dim1-4, m_dim2-4, m_dim3-4, 3)
        dfdxx = gradient_dxyz(dfdx, gradient_dx)
        dfdyy = gradient_dxyz(dfdy, gradient_dy)
        dfdzz = gradient_dxyz(dfdz, gradient_dz)
        dfdxy = gradient_dxyz(dfdx, gradient_dy)
        dfdyz = gradient_dxyz(dfdy, gradient_dz)
        dfdxz = gradient_dxyz(dfdx, gradient_dz)

        # Sum of squared Hessian entries: each mixed partial appears twice in
        # the symmetric Hessian (f_xy == f_yx, ...), hence the factor 2.
        energy = dfdxx ** 2 + dfdyy ** 2 + dfdzz ** 2
        energy += 2 * dfdxy ** 2 + 2 * dfdxz ** 2 + 2 * dfdyz ** 2
        # average over the spatial and component axes, keeping the batch axis
        return tf.reduce_mean(energy, axis=[1, 2, 3, 4])
154
|
|
|
|