deepreg.loss.deform (rating: A)

Complexity

Total Complexity: 10

Size/Duplication

Total Lines: 154
Duplicated Lines: 0 %

Importance

Changes: 0

Metric   Value
wmc      10
eloc     50
dl       0
loc      154
rs       10
c        0
b        0
f        0

4 Functions

Rating   Name              Duplication   Size   Complexity
A        gradient_dy()     0             11     1
A        gradient_dz()     0             11     1
A        gradient_dxyz()   0             11     1
A        gradient_dx()     0             11     1

5 Methods

Rating   Name                        Duplication   Size   Complexity
A        GradientNorm.get_config()   0             5      1
A        BendingEnergy.call()        0             29     1
A        BendingEnergy.__init__()    0             8      1
A        GradientNorm.call()         0             20     2
A        GradientNorm.__init__()     0             10     1

"""Provide regularization functions and classes for ddf (dense displacement field)."""
from typing import Callable

import tensorflow as tf

from deepreg.registry import REGISTRY


def gradient_dx(fx: tf.Tensor) -> tf.Tensor:
    """
    Calculate the gradient along the x-axis of a 3D volume using central finite difference.

    The tensor is shifted along axis 1 (the x-axis) to form the central-difference
    approximation dx[i] = (x[i+1] - x[i-1]) / 2, so one voxel is trimmed from each
    boundary of every spatial dimension.

    :param fx: shape = (batch, m_dim1, m_dim2, m_dim3)
    :return: shape = (batch, m_dim1-2, m_dim2-2, m_dim3-2)
    """
    return (fx[:, 2:, 1:-1, 1:-1] - fx[:, :-2, 1:-1, 1:-1]) / 2


def gradient_dy(fy: tf.Tensor) -> tf.Tensor:
    """
    Calculate the gradient along the y-axis of a 3D volume using central finite difference.

    The tensor is shifted along axis 2 (the y-axis) to form the central-difference
    approximation dy[i] = (y[i+1] - y[i-1]) / 2, so one voxel is trimmed from each
    boundary of every spatial dimension.

    :param fy: shape = (batch, m_dim1, m_dim2, m_dim3)
    :return: shape = (batch, m_dim1-2, m_dim2-2, m_dim3-2)
    """
    return (fy[:, 1:-1, 2:, 1:-1] - fy[:, 1:-1, :-2, 1:-1]) / 2


def gradient_dz(fz: tf.Tensor) -> tf.Tensor:
    """
    Calculate the gradient along the z-axis of a 3D volume using central finite difference.

    The tensor is shifted along axis 3 (the z-axis) to form the central-difference
    approximation dz[i] = (z[i+1] - z[i-1]) / 2, so one voxel is trimmed from each
    boundary of every spatial dimension.

    :param fz: shape = (batch, m_dim1, m_dim2, m_dim3)
    :return: shape = (batch, m_dim1-2, m_dim2-2, m_dim3-2)
    """
    return (fz[:, 1:-1, 1:-1, 2:] - fz[:, 1:-1, 1:-1, :-2]) / 2


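As a quick sanity check (illustrative only, with the three helpers above and
tensorflow in scope; the toy shape (2, 8, 8, 8) is an assumption), central
differencing trims one voxel from every spatial boundary:

    import tensorflow as tf

    fx = tf.random.uniform((2, 8, 8, 8))   # (batch, m_dim1, m_dim2, m_dim3)
    assert gradient_dx(fx).shape == (2, 6, 6, 6)
    assert gradient_dy(fx).shape == (2, 6, 6, 6)
    assert gradient_dz(fx).shape == (2, 6, 6, 6)
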
def gradient_dxyz(fxyz: tf.Tensor, fn: Callable) -> tf.Tensor:
    """
    Calculate the gradient along one axis for each channel of a tensor.

    The gradient function fn is applied to each of the three channels
    separately, and the results are stacked back together along the last axis.

    :param fxyz: shape = (batch, m_dim1, m_dim2, m_dim3, 3)
    :param fn: one of gradient_dx, gradient_dy or gradient_dz
    :return: shape = (batch, m_dim1-2, m_dim2-2, m_dim3-2, 3)
    """
    return tf.stack([fn(fxyz[..., i]) for i in [0, 1, 2]], axis=4)


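For example (a minimal sketch; the random ddf and its shape are assumptions),
applying gradient_dx channel-wise to a dense displacement field keeps the
channel axis last:

    ddf = tf.random.uniform((2, 8, 8, 8, 3))
    dfdx = gradient_dxyz(ddf, gradient_dx)
    print(dfdx.shape)  # (2, 6, 6, 6, 3)
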
@REGISTRY.register_loss(name="gradient")
class GradientNorm(tf.keras.layers.Layer):
    """
    Calculate the L1/L2 norm of the first-order gradient of a ddf using
    central finite difference.

    The input has to be a 5-d tensor, including the batch axis.
    """

    def __init__(self, l1: bool = False, name: str = "GradientNorm", **kwargs):
        """
        Init.

        :param l1: if True, calculate the L1 norm; otherwise the L2 norm
        :param name: name of the loss
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self.l1 = l1

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Return the loss for each sample in the batch.

        :param inputs: shape = (batch, m_dim1, m_dim2, m_dim3, 3)
        :param kwargs: additional arguments.
        :return: shape = (batch,)
        """
        assert len(inputs.shape) == 5
        ddf = inputs
        # first-order gradients, each of shape
        # (batch, m_dim1-2, m_dim2-2, m_dim3-2, 3)
        dfdx = gradient_dxyz(ddf, gradient_dx)
        dfdy = gradient_dxyz(ddf, gradient_dy)
        dfdz = gradient_dxyz(ddf, gradient_dz)
        if self.l1:
            norms = tf.abs(dfdx) + tf.abs(dfdy) + tf.abs(dfdz)
        else:
            norms = dfdx ** 2 + dfdy ** 2 + dfdz ** 2
        return tf.reduce_mean(norms, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["l1"] = self.l1
        return config


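A minimal usage sketch (the shapes and values are illustrative; the layer can
also be looked up through REGISTRY under the "gradient" name registered above):

    ddf = tf.random.uniform((2, 8, 8, 8, 3))
    l2_loss = GradientNorm()(ddf)          # mean squared first-order gradient
    l1_loss = GradientNorm(l1=True)(ddf)   # mean absolute first-order gradient
    print(l2_loss.shape, l1_loss.shape)    # (2,) (2,)
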
@REGISTRY.register_loss(name="bending")
class BendingEnergy(tf.keras.layers.Layer):
    """
    Calculate the bending energy of a ddf using central finite difference.

    The input has to be a 5-d tensor, including the batch axis.
    """

    def __init__(self, name: str = "BendingEnergy", **kwargs):
        """
        Init.

        :param name: name of the loss.
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Return the loss for each sample in the batch.

        :param inputs: shape = (batch, m_dim1, m_dim2, m_dim3, 3)
        :param kwargs: additional arguments.
        :return: shape = (batch,)
        """
        assert len(inputs.shape) == 5
        ddf = inputs
        # first-order gradients, each of shape
        # (batch, m_dim1-2, m_dim2-2, m_dim3-2, 3)
        dfdx = gradient_dxyz(ddf, gradient_dx)
        dfdy = gradient_dxyz(ddf, gradient_dy)
        dfdz = gradient_dxyz(ddf, gradient_dz)

        # second-order gradients, each of shape
        # (batch, m_dim1-4, m_dim2-4, m_dim3-4, 3)
        dfdxx = gradient_dxyz(dfdx, gradient_dx)
        dfdyy = gradient_dxyz(dfdy, gradient_dy)
        dfdzz = gradient_dxyz(dfdz, gradient_dz)
        dfdxy = gradient_dxyz(dfdx, gradient_dy)
        dfdyz = gradient_dxyz(dfdy, gradient_dz)
        dfdxz = gradient_dxyz(dfdx, gradient_dz)

        # sum the squares of all second-order derivatives, counting each mixed
        # derivative twice, analogous to
        # (dx + dy + dz) ** 2 = dxx + dyy + dzz + 2 * (dxy + dyz + dzx)
        energy = dfdxx ** 2 + dfdyy ** 2 + dfdzz ** 2
        energy += 2 * dfdxy ** 2 + 2 * dfdxz ** 2 + 2 * dfdyz ** 2
        return tf.reduce_mean(energy, axis=[1, 2, 3, 4])
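The loss approximates the classic bending energy regularizer, i.e. the mean of
f_xx^2 + f_yy^2 + f_zz^2 + 2 * (f_xy^2 + f_yz^2 + f_xz^2) over the interior
voxels of each channel. A minimal usage sketch (the shape is illustrative; each
spatial dimension must be at least 5 so that the second-order gradients are
non-empty):

    ddf = tf.random.uniform((2, 10, 10, 10, 3))
    loss = BendingEnergy()(ddf)
    print(loss.shape)  # (2,)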