"""Provide different loss or metrics classes for images."""

import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
    rectangular_kernel1d,
    separable_filter,
    triangular_kernel1d,
)
from deepreg.registry import REGISTRY


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Sum of squared distance between y_true and y_pred.

    The per-sample value is the mean of squared differences over all
    non-batch axes. y_true and y_pred have to be at least 1-d tensors,
    including the batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "SumSquaredDifference",
    ):
        """
        Init.

        :param reduction: SUM reduction over the batch axis. This supports
            multi-device training, where the loss is divided by the global
            batch size; calling the loss like `loss(y_true, y_pred)` returns
            a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        loss = tf.keras.layers.Flatten()(loss)
        return tf.reduce_mean(loss, axis=1)
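

# --- Editor's illustrative sketch, not part of the original module. ---
# Minimal usage of SumSquaredDifference, assuming eager execution; the demo
# name and tensor shapes below are made up. `call` returns one value per
# sample (the mean squared difference over non-batch axes); the SUM reduction
# in `__call__` then sums those values over the batch.
def _demo_sum_squared_difference() -> None:
    loss_fn = SumSquaredDifference()
    y_true = tf.ones((2, 4, 4, 4))
    y_pred = tf.zeros((2, 4, 4, 4))
    per_sample = loss_fn.call(y_true=y_true, y_pred=y_pred)  # (2,), each 1.0
    total = loss_fn(y_true, y_pred)  # scalar, 2.0 under SUM reduction
    print(per_sample.numpy(), total.numpy())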


class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via the Parzen windowing method.

    y_true and y_pred have to be at least 4-d tensors, including the batch axis.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equation 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "GlobalMutualInformation",
    ):
        """
        Init.

        :param num_bins: number of intensity bins; the default value is empirical.
        :param sigma_ratio: a hyper-parameter for the Gaussian weighting function.
        :param reduction: SUM reduction over the batch axis. This supports
            multi-device training, where the loss is divided by the global
            batch size; calling the loss like `loss(y_true, y_pred)` returns
            a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust shapes so both inputs are 5-d
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensities are split into bins between 0 and 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma in the Gaussian function (weighting function W)
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar
        batch, w, h, z, c = y_true.shape
        y_true = tf.reshape(y_true, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        nb_voxels = y_true.shape[1] * 1.0  # w * h * z * c, number of voxels

        # each voxel contributes continuously to a range of histogram bins
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI: sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
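

# --- Editor's illustrative sketch, not part of the original module. ---
# Toy usage of GlobalMutualInformation, assuming eager execution; the demo
# name and shapes are made up. Calling `call` directly bypasses the reduction
# and returns one MI value per sample. MI of an image with itself equals its
# soft-binned entropy, so it is typically larger than MI against an
# independent image.
def _demo_global_mutual_information() -> None:
    gmi = GlobalMutualInformation(num_bins=23, sigma_ratio=0.5)
    y = tf.random.uniform((2, 8, 8, 8))  # intensities in [0, 1)
    mi_self = gmi.call(y_true=y, y_pred=y)  # (2,)
    mi_indep = gmi.call(y_true=y, y_pred=tf.random.uniform((2, 8, 8, 8)))  # (2,)
    print(mi_self.numpy(), mi_indep.numpy())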


@REGISTRY.register_loss(name="gmi")
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
    """Reverse the sign of GlobalMutualInformation."""


class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    Denote y_true as t and y_pred as p. Consider a window having n elements.
    Each position in the window corresponds to a weight w_i for i = 1:n.

    Define the discrete expectation in the window E[t] as

        E[t] = sum_i(w_i * t_i) / sum_i(w_i)

    Similarly, the discrete variance in the window V[t] is

        V[t] = E[t ** 2] - E[t] ** 2

    The local squared zero-normalized cross-correlation is therefore

        E[ (t - E[t]) * (p - E[p]) ] ** 2 / V[t] / V[p]

    where the expectation in the numerator is

        E[ (t - E[t]) * (p - E[p]) ] = E[t * p] - E[t] * E[p]

    Different kernels correspond to different weights.

    For now, y_true and y_pred have to be at least 4-d tensors, including
    the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """

    kernel_fn_dict = dict(
        gaussian=gaussian_kernel1d,
        rectangular=rectangular_kernel1d,
        triangular=triangular_kernel1d,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LocalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param kernel_size: int. Kernel size, or kernel sigma for
            kernel_type="gaussian".
        :param kernel_type: str, one of rectangular, triangular or gaussian.
        :param smooth_nr: small constant added to numerator in case of zero covariance.
        :param smooth_dr: small constant added to denominator in case of zero variance.
        :param reduction: SUM reduction over the batch axis. This supports
            multi-device training, where the loss is divided by the global
            batch size; calling the loss like `loss(y_true, y_pred)` returns
            a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        if kernel_type not in self.kernel_fn_dict:
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {list(self.kernel_fn_dict.keys())}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr

        # (kernel_size, )
        self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
        # E[1] = sum_i(w_i), the volume of the 3-d separable kernel, ()
        self.kernel_vol = tf.reduce_sum(
            self.kernel[:, None, None]
            * self.kernel[None, :, None]
            * self.kernel[None, None, :]
        )

    def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return NCC for a batch.

        The kernel is not normalized, as normalizing it leads to computation
        with small values, which reduces precision. Here both the numerator
        and the denominator are effectively multiplied by the kernel volume,
        which also helps precision. However, when the variance is zero, the
        computed value may become negative due to machine error, so the
        variances are clipped at zero and the denominator is smoothed to
        prevent division by zero.

        :param y_true: shape = (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch, dim1, dim2, dim3, 1)
        """

        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, 1)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = separable_filter(y_true, kernel=self.kernel)  # E[t] * E[1]
        p_sum = separable_filter(y_pred, kernel=self.kernel)  # E[p] * E[1]
        t2_sum = separable_filter(t2, kernel=self.kernel)  # E[tt] * E[1]
        p2_sum = separable_filter(p2, kernel=self.kernel)  # E[pp] * E[1]
        tp_sum = separable_filter(tp, kernel=self.kernel)  # E[tp] * E[1]

        # average over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_avg = t_sum / self.kernel_vol  # E[t]
        p_avg = p_sum / self.kernel_vol  # E[p]

        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_avg * t_sum  # E[tp] * E[1] - E[p] * E[t] * E[1]
        t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
        p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]

        # ensure variance >= 0
        t_var = tf.maximum(t_var, 0)
        p_var = tf.maximum(p_var, 0)

        # the kernel volumes E[1] in numerator and denominator cancel:
        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
        ncc = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr)

        return ncc

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        TODO: support channel axis dimension > 1.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch,)
        """
        # sanity checks
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
        if y_true.shape[4] != 1:
            raise ValueError(
                f"Last dimension of y_true is not one. y_true.shape = {y_true.shape}"
            )
        if len(y_pred.shape) == 4:
            y_pred = tf.expand_dims(y_pred, axis=4)
        if y_pred.shape[4] != 1:
            raise ValueError(
                f"Last dimension of y_pred is not one. y_pred.shape = {y_pred.shape}"
            )

        ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            kernel_size=self.kernel_size,
            kernel_type=self.kernel_type,
            smooth_nr=self.smooth_nr,
            smooth_dr=self.smooth_dr,
        )
        return config
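

# --- Editor's illustrative sketch, not part of the original module. ---
# Sanity check of calc_ncc, assuming eager execution; the demo name and
# shapes are made up. For y_true == y_pred, `cross` equals the variance term,
# so the squared ZNCC is ~1 wherever the local variance is non-zero, and the
# per-sample mean is close to 1 for a non-constant image.
def _demo_lncc_identity() -> None:
    lncc = LocalNormalizedCrossCorrelation(kernel_size=9, kernel_type="rectangular")
    y = tf.random.uniform((1, 12, 12, 12, 1))
    per_sample = lncc.call(y_true=y, y_pred=y)  # (1,), value near 1.0
    print(per_sample.numpy())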


@REGISTRY.register_loss(name="lncc")
class LocalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, LocalNormalizedCrossCorrelation
):
    """Reverse the sign of LocalNormalizedCrossCorrelation."""


class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving
    images. y_true and y_pred have to be at least 4-d tensors, including
    the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "GlobalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param reduction: AUTO reduction; calling the loss like
            `loss(y_true, y_pred)` returns a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        axis = list(range(1, len(y_true.shape)))
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        numerator = tf.abs(
            tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
        )

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)


@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, GlobalNormalizedCrossCorrelation
):
    """Reverse the sign of GlobalNormalizedCrossCorrelation."""
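

# --- Editor's illustrative sketch, not part of the original module. ---
# The global squared ZNCC is invariant to positive affine intensity changes:
# for y_pred = a * y_true + b with a > 0, the covariance is a * V[y_true] and
# the denominator is a ** 2 * V[y_true] ** 2, so the score is near 1. The demo
# name and shapes are made up; eager execution is assumed.
def _demo_gncc_affine_invariance() -> None:
    gncc = GlobalNormalizedCrossCorrelation()
    y_true = tf.random.uniform((2, 4, 4, 4))
    y_pred = 2.0 * y_true + 3.0
    print(gncc.call(y_true=y_true, y_pred=y_pred).numpy())  # (2,), near 1.0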