Passed
Pull Request — main (#733), created by Yunguan, 01:34

LocalNormalizedCrossCorrelation.calc_ncc() (grade A)

Complexity: Conditions 2
Size: Total Lines 47, Code Lines 21
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Metric  Value
eloc    21
dl      0
loc     47
rs      9.376
c       0
b       0
f       0
cc      2
nop     3
"""Provide different loss or metrics classes for images."""
import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
    rectangular_kernel1d,
    separable_filter,
    triangular_kernel1d,
)
from deepreg.registry import REGISTRY


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Sum of squared differences between y_true and y_pred.

    y_true and y_pred have to be at least 1d tensors, including the batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "SumSquaredDifference",
    ):
        """
        Init.

        :param reduction: using SUM reduction over the batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        loss = tf.keras.layers.Flatten()(loss)
        return tf.reduce_mean(loss, axis=1)
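
A minimal usage sketch (illustrative, not part of the module): with the default SUM reduction, calling the loss object returns a scalar, while `call` alone returns one value per sample.

# toy batch of 3d images, shape (batch, dim1, dim2, dim3)
y_true = tf.random.uniform((2, 4, 4, 4))
y_pred = tf.random.uniform((2, 4, 4, 4))

loss_fn = SumSquaredDifference()
scalar = loss_fn(y_true, y_pred)           # per-sample means, summed over the batch
per_sample = loss_fn.call(y_true, y_pred)  # shape (2,)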


class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via Parzen windowing method.

    y_true and y_pred have to be at least 4d tensors, including the batch axis.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equations 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "GlobalMutualInformation",
    ):
        """
        Init.

        :param num_bins: number of bins for intensity; the default value is empirical.
        :param sigma_ratio: a hyper-parameter for the Gaussian function
        :param reduction: using SUM reduction over the batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensity is split into bins between 0 and 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma in the Gaussian function (weighting function W)
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar
        batch, w, h, z, c = y_true.shape
        y_true = tf.reshape(y_true, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        nb_voxels = y_true.shape[1] * 1.0  # w * h * z * c, number of voxels

        # each voxel contributes continuously to a range of histogram bins
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI: sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
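
A quick sanity check, illustrative rather than part of the module: the Parzen-window MI of an image with itself should clearly exceed its MI with independent noise.

fixed = tf.random.uniform((1, 8, 8, 8))
noise = tf.random.uniform((1, 8, 8, 8))

gmi = GlobalMutualInformation(num_bins=23, sigma_ratio=0.5)
mi_self = gmi.call(fixed, fixed)   # shape (1,), high: intensities are perfectly dependent
mi_noise = gmi.call(fixed, noise)  # shape (1,), near zero: intensities are independent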


@REGISTRY.register_loss(name="gmi")
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
    """Reverse the sign of GlobalMutualInformation."""


class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    Denote y_true as t and y_pred as p. Consider a window having n elements.
    Each position in the window corresponds to a weight w_i for i=1:n.

    Define the discrete expectation in the window E[t] as

        E[t] = sum_i(w_i * t_i) / sum_i(w_i)

    Similarly, the discrete variance in the window V[t] is

        V[t] = E[t**2] - E[t] ** 2

    The local squared zero-normalized cross-correlation is therefore

        E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

    where the expectation in the numerator is

        E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

    Different kernels correspond to different weights.

    For now, y_true and y_pred have to be at least 4d tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """

    kernel_fn_dict = dict(
        gaussian=gaussian_kernel1d,
        rectangular=rectangular_kernel1d,
        triangular=triangular_kernel1d,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LocalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param kernel_size: int. Kernel size, or kernel sigma for kernel_type="gaussian".
        :param kernel_type: str, rectangular, triangular or gaussian
        :param reduction: using SUM reduction over the batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        if kernel_type not in self.kernel_fn_dict.keys():
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {self.kernel_fn_dict.keys()}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size

        # (kernel_size, )
        self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
        # E[1] = sum_i(w_i), ()
        self.kernel_vol = tf.reduce_sum(
            self.kernel[:, None, None]
            * self.kernel[None, :, None]
            * self.kernel[None, None, :]
        )

    def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return NCC for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch, dim1, dim2, dim3, 1)
        """
        # adjust
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)

        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, ch)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = separable_filter(y_true, kernel=self.kernel)  # E[t] * E[1]
        p_sum = separable_filter(y_pred, kernel=self.kernel)  # E[p] * E[1]
        t2_sum = separable_filter(t2, kernel=self.kernel)  # E[tt] * E[1]
        p2_sum = separable_filter(p2, kernel=self.kernel)  # E[pp] * E[1]
        tp_sum = separable_filter(tp, kernel=self.kernel)  # E[tp] * E[1]

        # average over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_avg = t_sum / self.kernel_vol  # E[t]
        p_avg = p_sum / self.kernel_vol  # E[p]

        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_avg * t_sum  # E[tp] * E[1] - E[p] * E[t] * E[1]
        t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
        p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]

        # ensure variance >= 0
        t_var = tf.maximum(t_var, 0)
        p_var = tf.maximum(p_var, 0)

        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
        ncc = (cross * cross + EPS) / (t_var * p_var + EPS)

        return ncc

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["kernel_size"] = self.kernel_size
        config["kernel_type"] = self.kernel_type
        return config
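
An illustrative check of the math above (not part of the module): since the squared ZNCC normalizes by both local variances, an image compared against itself, or against any affine intensity remapping of itself, scores approximately 1 everywhere.

t = tf.random.uniform((1, 8, 8, 8))

lncc = LocalNormalizedCrossCorrelation(kernel_size=5, kernel_type="rectangular")
same = lncc.call(t, t)                # shape (1,), ~1.0
affine = lncc.call(t, 0.5 * t + 0.2)  # also ~1.0: E[tp] - E[t]E[p] scales with V[t]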


@REGISTRY.register_loss(name="lncc")
class LocalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, LocalNormalizedCrossCorrelation
):
    """Reverse the sign of LocalNormalizedCrossCorrelation."""


class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving images.
    y_true and y_pred have to be at least 4d tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation

    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "GlobalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param reduction: using AUTO reduction,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """

        axis = [a for a in range(1, len(y_true.shape))]
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        numerator = tf.abs(
            tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
        )

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)
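
Again purely illustrative: because of the absolute value in the numerator, perfect correlation and perfect anti-correlation both score approximately 1, while independent noise scores near 0.

t = tf.random.uniform((2, 4, 4, 4))

gncc = GlobalNormalizedCrossCorrelation()
gncc.call(t, t)        # ~[1., 1.]
gncc.call(t, 1.0 - t)  # ~[1., 1.], anti-correlation scores equally high
gncc.call(t, tf.random.uniform((2, 4, 4, 4)))  # ~[0., 0.]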


@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, GlobalNormalizedCrossCorrelation
):
    """Reverse the sign of GlobalNormalizedCrossCorrelation."""
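
All three registered losses follow the same pattern: NegativeLossMixin (imported from deepreg.loss.util) wraps a similarity metric so that maximizing similarity means minimizing the loss. A minimal sketch, assuming the mixin simply negates the wrapped metric as the docstrings state:

metric = GlobalNormalizedCrossCorrelation()
loss = GlobalNormalizedCrossCorrelationLoss()

t = tf.random.uniform((2, 4, 4, 4))
p = tf.random.uniform((2, 4, 4, 4))
# expected: loss.call(t, p) == -metric.call(t, p) for each sample in the batch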