Passed
Pull Request — main (#719) by Yunguan, created 01:41

LocalNormalizedCrossCorrelation._call()   A

Complexity
    Conditions: 2

Size
    Total Lines: 42
    Code Lines: 20

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
eloc     20
dl       0
loc      42
rs       9.4
c        0
b        0
f        0
cc       2
nop      3

"""Provide different loss or metrics classes for images."""

from typing import Tuple

import tensorflow as tf

from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
    rectangular_kernel1d,
    separable_filter,
    triangular_kernel1d,
)
from deepreg.registry import REGISTRY

EPS = 1.0e-5


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Sum of squared difference between y_true and y_pred.

    y_true and y_pred have to be at least 1d tensor, including batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "SumSquaredDifference",
    ):
        """
        Init.

        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        loss = tf.keras.layers.Flatten()(loss)
        return tf.reduce_mean(loss, axis=1)
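

# Illustrative usage sketch, not part of the original module: with
# Reduction.NONE the loss returns one mean-squared-difference value per
# batch element. `_demo_ssd` is a hypothetical helper for demonstration only.
def _demo_ssd() -> tf.Tensor:
    y_true = tf.ones((2, 4, 4), dtype=tf.float32)
    y_pred = tf.zeros((2, 4, 4), dtype=tf.float32)
    loss_fn = SumSquaredDifference(reduction=tf.keras.losses.Reduction.NONE)
    # the squared difference is 1 everywhere, so the mean per sample is 1.0
    return loss_fn(y_true, y_pred)  # shape (2,), values [1.0, 1.0]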


class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via Parzen windowing method.

    y_true and y_pred have to be at least 4d tensor, including batch axis.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equation 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "GlobalMutualInformation",
    ):
        """
        Init.

        :param num_bins: number of bins for intensity, the default value is empirical.
        :param sigma_ratio: a hyperparameter for the Gaussian weighting function
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust shapes to (batch, dim1, dim2, dim3, ch)
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensity is split into bins between 0, 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma in the Gaussian function (weighting function W)
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar
        batch, w, h, z, c = y_true.shape
        y_true = tf.reshape(y_true, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        nb_voxels = y_true.shape[1] * 1.0  # w * h * z * c, number of voxels

        # each voxel contributes continuously to a range of histogram bins
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI: sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
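

# Illustrative sanity check, not part of the original module: the mutual
# information of an image with itself should exceed that of the image with
# independent noise. `_demo_gmi` is a hypothetical helper for demonstration only.
def _demo_gmi() -> Tuple[tf.Tensor, tf.Tensor]:
    image = tf.random.uniform((1, 8, 8, 8), seed=0)
    noise = tf.random.uniform((1, 8, 8, 8), seed=1)
    gmi = GlobalMutualInformation(reduction=tf.keras.losses.Reduction.NONE)
    # the first value should be noticeably larger than the second
    return gmi(image, image), gmi(image, noise)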


@REGISTRY.register_loss(name="gmi")
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
    """Reverse the sign of GlobalMutualInformation."""


class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    Denote y_true as t and y_pred as p. Consider a window having n elements.
    Each position in the window corresponds to a weight w_i for i=1:n.

    Define the discrete expectation in the window E[t] as

        E[t] = sum_i(w_i * t_i) / sum_i(w_i)

    Similarly, the discrete variance in the window V[t] is

        V[t] = E[t**2] - E[t] ** 2

    The local squared zero-normalized cross-correlation is therefore

        E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

    where the expectation in the numerator is

        E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

    Different kernels correspond to different weights.

    For now, y_true and y_pred have to be at least 4d tensor, including batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """

    kernel_fn_dict = dict(
        gaussian=gaussian_kernel1d,
        rectangular=rectangular_kernel1d,
        triangular=triangular_kernel1d,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LocalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param kernel_size: int. Kernel size, or kernel sigma for kernel_type="gaussian".
        :param kernel_type: str, rectangular, triangular or gaussian
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        if kernel_type not in self.kernel_fn_dict:
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {self.kernel_fn_dict.keys()}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size

        # (kernel_size, )
        self.kernel = self.kernel_fn(kernel_size=self.kernel_size)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        num, denom = self._call(y_true, y_pred)
        num = num + EPS
        denom = denom + EPS
        ncc = num / denom

        ncc = tf.debugging.check_numerics(
            ncc, "LNCC ncc value NAN/INF", name="LNCC_before_mean"
        )

        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def _call(
        self, y_true: tf.Tensor, y_pred: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Return the numerator and denominator of LNCC for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: two tensors of shape (batch, dim1, dim2, dim3, 1),
            the numerator and denominator of the squared LNCC.
        """
        # adjust shapes to (batch, dim1, dim2, dim3, ch)
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, ch)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = separable_filter(y_true, kernel=self.kernel)
        p_sum = separable_filter(y_pred, kernel=self.kernel)
        t2_sum = separable_filter(t2, kernel=self.kernel)
        p2_sum = separable_filter(p2, kernel=self.kernel)
        tp_sum = separable_filter(tp, kernel=self.kernel)

        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_sum * t_sum
        t_var = t2_sum - t_sum * t_sum
        p_var = p2_sum - p_sum * p_sum

        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
        num = cross * cross
        denom = t_var * p_var

        return num, denom

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["kernel_size"] = self.kernel_size
        config["kernel_type"] = self.kernel_type
        return config
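

# Illustrative sanity check, not part of the original module: for identical
# inputs the numerator and denominator of the squared LNCC coincide, so the
# value is 1 at every voxel. `_demo_lncc` is a hypothetical helper for
# demonstration only.
def _demo_lncc() -> tf.Tensor:
    image = tf.random.uniform((1, 8, 8, 8, 1), seed=0)
    lncc = LocalNormalizedCrossCorrelation(
        kernel_size=9,
        kernel_type="rectangular",
        reduction=tf.keras.losses.Reduction.NONE,
    )
    return lncc(image, image)  # shape (1,), value 1.0 up to floating point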


@REGISTRY.register_loss(name="lncc")
class LocalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, LocalNormalizedCrossCorrelation
):
    """Reverse the sign of LocalNormalizedCrossCorrelation."""


class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving images.
    y_true and y_pred have to be at least 4d tensor, including batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "GlobalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param reduction: using AUTO reduction,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        axis = list(range(1, len(y_true.shape)))
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        numerator = tf.abs(
            tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
        )

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)
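

# Illustrative sanity check, not part of the original module: squared ZNCC is
# invariant to affine intensity changes, so an image compared against a scaled
# and shifted copy of itself still scores close to 1. `_demo_gncc` is a
# hypothetical helper for demonstration only.
def _demo_gncc() -> tf.Tensor:
    image = tf.random.uniform((2, 4, 4, 4), seed=0)
    scaled = 2.0 * image + 0.5  # affine intensity transform of the same image
    gncc = GlobalNormalizedCrossCorrelation(
        reduction=tf.keras.losses.Reduction.NONE
    )
    return gncc(image, scaled)  # shape (2,), values close to 1.0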


@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, GlobalNormalizedCrossCorrelation
):
    """Reverse the sign of GlobalNormalizedCrossCorrelation."""