Passed
Pull Request — main (#733)
by Yunguan, created 01:18

LocalNormalizedCrossCorrelation.calc_ncc() — grade: A

Complexity
  Conditions: 1

Size
  Total Lines: 49
  Code Lines: 18

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric                            Value
eloc (effective lines of code)    18
dl (duplicated lines)             0
loc (lines of code)               49
rs                                9.5
c                                 0
b                                 0
f                                 0
cc (cyclomatic complexity)        1
nop (number of parameters)        3
"""Provide different loss or metrics classes for images."""
import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
    rectangular_kernel1d,
    separable_filter,
    triangular_kernel1d,
)
from deepreg.registry import REGISTRY


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Sum of squared distance between y_true and y_pred.

    The per-sample value is the mean squared difference over all non-batch
    axes; the default SUM reduction then adds these values over the batch.

    y_true and y_pred have to be at least 1-d tensors, including the batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "SumSquaredDifference",
    ):
        """
        Init.

        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        loss = tf.keras.layers.Flatten()(loss)
        return tf.reduce_mean(loss, axis=1)
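# Usage sketch (illustrative shapes, not part of the original module): the
# per-sample loss has shape (batch,), and the default SUM reduction adds the
# samples into a scalar:
#
#   >>> ssd = SumSquaredDifference()
#   >>> ssd(tf.zeros((2, 4, 4)), tf.ones((2, 4, 4)))  # mean 1.0 per sample
#   <tf.Tensor: shape=(), dtype=float32, numpy=2.0>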
class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via Parzen windowing method.

    y_true and y_pred have to be at least 4-d tensors, including the batch axis.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equation 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "GlobalMutualInformation",
    ):
        """
        Init.

        :param num_bins: number of bins for intensity, the default value is empirical.
        :param sigma_ratio: a hyper-parameter for the Gaussian function
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust shapes to be 5-d
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensity is split into bins between 0, 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma in the Gaussian function (weighting function W)
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar
        batch, w, h, z, c = y_true.shape
        y_true = tf.reshape(y_true, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        nb_voxels = y_true.shape[1] * 1.0  # w * h * z * c, number of voxels

        # each voxel contributes continuously to a range of histogram bins
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI = sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
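# Shape walk-through (sketch, illustrative values): intensities are clipped to
# [0, 1] and softly assigned to num_bins Gaussian-weighted bins; pa and pb are
# the per-image marginal histograms and pab the joint histogram, so the return
# value is the mutual information per sample:
#
#   >>> gmi = GlobalMutualInformation(num_bins=23)
#   >>> t = tf.random.uniform((2, 8, 8, 8))
#   >>> p = tf.random.uniform((2, 8, 8, 8))
#   >>> gmi(t, p).shape  # SUM reduction over the batch -> scalar
#   TensorShape([])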
@REGISTRY.register_loss(name="gmi")
143
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
144
    """Revert the sign of GlobalMutualInformation."""
145
146
147
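# Note (an assumption from how the mixin is used here, see deepreg.loss.util):
# NegativeLossMixin negates the parent metric, so maximizing mutual
# information corresponds to minimizing this loss:
#
#   >>> gmi_loss = GlobalMutualInformationLoss(reduction="none")
#   >>> gmi = GlobalMutualInformation(reduction="none")
#   >>> bool(tf.reduce_all(gmi_loss(t, p) == -gmi(t, p)))  # t, p as above
#   True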
class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    Denote y_true as t and y_pred as p. Consider a window having n elements.
    Each position in the window corresponds to a weight w_i for i=1:n.

    Define the discrete expectation in the window E[t] as

        E[t] = sum_i(w_i * t_i) / sum_i(w_i)

    Similarly, the discrete variance in the window V[t] is

        V[t] = E[t**2] - E[t] ** 2

    The local squared zero-normalized cross-correlation is therefore

        E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

    where the expectation in the numerator is

        E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

    Different kernels correspond to different weights; see the worked example
    below.

    For now, y_true and y_pred have to be at least 4-d tensors, including
    the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """
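    # Worked example (illustrative numbers): with a rectangular kernel all
    # w_i = 1, so E[t] is the plain window mean. For a window t = (1, 2, 3):
    #   E[t] = (1 + 2 + 3) / 3 = 2
    #   V[t] = E[t**2] - E[t] ** 2 = 14/3 - 4 = 2/3
    # Gaussian and triangular kernels instead weight central voxels more.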
    kernel_fn_dict = dict(
        gaussian=gaussian_kernel1d,
        rectangular=rectangular_kernel1d,
        triangular=triangular_kernel1d,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LocalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param kernel_size: int. Kernel size, or kernel sigma for
            kernel_type="gaussian".
        :param kernel_type: str, one of "rectangular", "triangular" or "gaussian"
        :param smooth_nr: small constant added to numerator in case of zero covariance.
        :param smooth_dr: small constant added to denominator in case of zero variance.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        if kernel_type not in self.kernel_fn_dict.keys():
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {self.kernel_fn_dict.keys()}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr

        # (kernel_size, )
        self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
        # E[1] = sum_i(w_i), shape = ()
        self.kernel_vol = tf.reduce_sum(
            self.kernel[:, None, None]
            * self.kernel[None, :, None]
            * self.kernel[None, None, :]
        )
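    # Note (assuming rectangular_kernel1d returns an all-ones kernel): with the
    # default kernel_size=9, kernel_vol = 9 ** 3 = 729. This is E[1], the
    # unnormalized window volume used to rescale the filtered sums in calc_ncc.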
    def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return NCC for a batch.

        The kernel should not be normalized, as normalizing it leads to
        computation with small values and reduces precision. Here both the
        numerator and the denominator are effectively multiplied by the
        kernel volume, which also helps precision.
        However, when the variance is zero, the computed value may be
        slightly negative due to floating-point error, so the variances are
        clipped at zero and the smoothing constants guard the division.

        :param y_true: shape = (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch, dim1, dim2, dim3, 1)
        """

        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, 1)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = separable_filter(y_true, kernel=self.kernel)  # E[t] * E[1]
        p_sum = separable_filter(y_pred, kernel=self.kernel)  # E[p] * E[1]
        t2_sum = separable_filter(t2, kernel=self.kernel)  # E[tt] * E[1]
        p2_sum = separable_filter(p2, kernel=self.kernel)  # E[pp] * E[1]
        tp_sum = separable_filter(tp, kernel=self.kernel)  # E[tp] * E[1]

        # average over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_avg = t_sum / self.kernel_vol  # E[t]
        p_avg = p_sum / self.kernel_vol  # E[p]

        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_avg * t_sum  # E[tp] * E[1] - E[p] * E[t] * E[1]
        t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
        p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]

        # ensure variance >= 0
        t_var = tf.maximum(t_var, 0)
        p_var = tf.maximum(p_var, 0)

        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
        ncc = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr)

        return ncc
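    # Sanity checks (sketch): for y_true == y_pred over a non-constant window,
    # cross == t_var == p_var, so ncc is close to 1; over a constant window all
    # three terms vanish and ncc is close to smooth_nr / smooth_dr.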
    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        TODO: support channel axis dimension > 1.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch,)
        """
        # sanity checks
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
        if y_true.shape[4] != 1:
            raise ValueError(
                "Last dimension of y_true is not one. " f"y_true.shape = {y_true.shape}"
            )
        if len(y_pred.shape) == 4:
            y_pred = tf.expand_dims(y_pred, axis=4)
        if y_pred.shape[4] != 1:
            raise ValueError(
                "Last dimension of y_pred is not one. " f"y_pred.shape = {y_pred.shape}"
            )

        ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            kernel_size=self.kernel_size,
            kernel_type=self.kernel_type,
            smooth_nr=self.smooth_nr,
            smooth_dr=self.smooth_dr,
        )
        return config
@REGISTRY.register_loss(name="lncc")
class LocalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, LocalNormalizedCrossCorrelation
):
    """Reverse the sign of LocalNormalizedCrossCorrelation."""
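# Usage sketch (illustrative shapes): "lncc" is the registry name for this
# loss; direct construction works the same way:
#
#   >>> lncc_loss = LocalNormalizedCrossCorrelationLoss(kernel_size=9)
#   >>> t = tf.random.uniform((2, 8, 8, 8))  # channel axis optional, must be 1
#   >>> p = tf.random.uniform((2, 8, 8, 8))
#   >>> lncc_loss(t, p).shape  # SUM reduction over the batch -> scalar
#   TensorShape([])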
class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving
    images, i.e. the squared Pearson correlation coefficient over all
    non-batch axes.
    y_true and y_pred have to be at least 4-d tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation

    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "GlobalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param reduction: using AUTO reduction,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """

        axis = [a for a in range(1, len(y_true.shape))]
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        numerator = tf.abs(
            tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
        )

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)
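# Worked check (sketch): the per-sample value is r ** 2 for Pearson
# correlation r, so a perfect linear relation gives a value near 1:
#
#   >>> gncc = GlobalNormalizedCrossCorrelation()
#   >>> t = tf.random.uniform((2, 4, 4, 4))
#   >>> float(gncc(t, 2.0 * t + 1.0))  # AUTO reduction -> scalar, ~= 1.0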
@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, GlobalNormalizedCrossCorrelation
):
    """Reverse the sign of GlobalNormalizedCrossCorrelation."""
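# Minimal smoke test (a sketch, not part of the original module; assumes the
# deepreg imports above resolve). Each loss maps two (batch, dim1, dim2, dim3)
# tensors to a scalar under its default reduction.
if __name__ == "__main__":
    _t = tf.random.uniform((2, 8, 8, 8))
    _p = tf.random.uniform((2, 8, 8, 8))
    for _loss in (
        SumSquaredDifference(),
        GlobalMutualInformationLoss(),
        LocalNormalizedCrossCorrelationLoss(),
        GlobalNormalizedCrossCorrelationLoss(),
    ):
        print(_loss.name, float(_loss(_t, _p)))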