LocalNormalizedCrossCorrelation.get_config()   A

last analyzed

Complexity
    Conditions    1

Size
    Total Lines   10
    Code Lines    8

Duplication
    Lines         0
    Ratio         0 %

Importance
    Changes       0

Metric    Value
eloc      8
dl        0
loc       10
rs        10
c         0
b         0
f         0
cc        1
nop       1

"""Provide different loss or metric classes for images."""
import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.kernel import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.kernel import rectangular_kernel1d, triangular_kernel1d
from deepreg.loss.util import NegativeLossMixin, separable_filter
from deepreg.registry import REGISTRY

class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via Parzen windowing method.

    y_true and y_pred have to be at least 4d tensors, including the batch axis.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equations 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        name: str = "GlobalMutualInformation",
        **kwargs,
    ):
        """
        Init.

        :param num_bins: number of intensity bins; the default value is empirical.
        :param sigma_ratio: a hyperparameter of the Gaussian weighting function
        :param name: name of the loss
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensity is split into bins between 0 and 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma of the Gaussian weighting function W
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar
        batch, w, h, z, c = y_true.shape
        y_true = tf.reshape(y_true, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        nb_voxels = y_true.shape[1] * 1.0  # w * h * z * c, number of voxels

        # each voxel contributes continuously to a range of histogram bins
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI = sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
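

# A minimal sanity-check sketch (hypothetical helper, not part of the original
# module), assuming TensorFlow 2.x eager mode: mutual information between an
# image and itself should clearly exceed mutual information with unrelated noise.
def _demo_global_mutual_information() -> None:
    loss_fn = GlobalMutualInformation(num_bins=23, sigma_ratio=0.5)
    x = tf.random.uniform((2, 8, 8, 8))  # (batch, dim1, dim2, dim3)
    noise = tf.random.uniform((2, 8, 8, 8))
    mi_same = loss_fn.call(y_true=x, y_pred=x)  # shape (2,), larger values
    mi_diff = loss_fn.call(y_true=x, y_pred=noise)  # shape (2,), near zero
    print(float(tf.reduce_mean(mi_same)), float(tf.reduce_mean(mi_diff)))
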

@REGISTRY.register_loss(name="gmi")
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
    """Revert the sign of GlobalMutualInformation."""
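

# Illustrative sketch (hypothetical helper, not part of the original module):
# NegativeLossMixin negates the metric, so that maximizing mutual information
# becomes minimizing the loss. When invoked through tf.keras.losses.Loss.__call__,
# the per-sample values are additionally reduced to a scalar.
def _demo_gmi_loss_sign() -> None:
    x = tf.random.uniform((2, 8, 8, 8))
    metric = tf.reduce_mean(GlobalMutualInformation().call(y_true=x, y_pred=x))
    loss = GlobalMutualInformationLoss()(y_true=x, y_pred=x)  # scalar
    print(float(metric), float(loss))  # loss is approximately -metric
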

class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    Denote y_true as t and y_pred as p. Consider a window having n elements.
    Each position in the window corresponds to a weight w_i for i=1:n.

    Define the discrete expectation in the window E[t] as

        E[t] = sum_i(w_i * t_i) / sum_i(w_i)

    Similarly, the discrete variance in the window V[t] is

        V[t] = E[t**2] - E[t] ** 2

    The local squared zero-normalized cross-correlation is therefore

        E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

    where the expectation in the numerator is

        E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

    Different kernels correspond to different weights.

    For now, y_true and y_pred have to be at least 4d tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """

    kernel_fn_dict = dict(
        gaussian=gaussian_kernel1d,
        rectangular=rectangular_kernel1d,
        triangular=triangular_kernel1d,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        name: str = "LocalNormalizedCrossCorrelation",
        **kwargs,
    ):
        """
        Init.

        :param kernel_size: int. Kernel size, or kernel sigma for kernel_type="gaussian".
        :param kernel_type: str, one of rectangular, triangular or gaussian
        :param smooth_nr: small constant added to the numerator in case of zero covariance.
        :param smooth_dr: small constant added to the denominator in case of zero variance.
        :param name: name of the loss.
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        if kernel_type not in self.kernel_fn_dict:
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {list(self.kernel_fn_dict.keys())}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr

        # (kernel_size, )
        self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
        # E[1] = sum_i(w_i), ()
        self.kernel_vol = tf.reduce_sum(
            self.kernel[:, None, None]
            * self.kernel[None, :, None]
            * self.kernel[None, None, :]
        )

    def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return NCC for a batch.

        The kernel should not be normalized, as normalizing it leads to computation
        with small values and reduced precision.
        Here both the numerator and the denominator are actually multiplied by the
        kernel volume, which helps the precision as well.
        However, when the variance is zero, the obtained value might be negative due
        to machine error. Therefore a hard-coded clipping is added to keep the
        variances non-negative and prevent division by zero.

        :param y_true: shape = (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch, dim1, dim2, dim3, 1)
        """

        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, 1)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = separable_filter(y_true, kernel=self.kernel)  # E[t] * E[1]
        p_sum = separable_filter(y_pred, kernel=self.kernel)  # E[p] * E[1]
        t2_sum = separable_filter(t2, kernel=self.kernel)  # E[tt] * E[1]
        p2_sum = separable_filter(p2, kernel=self.kernel)  # E[pp] * E[1]
        tp_sum = separable_filter(tp, kernel=self.kernel)  # E[tp] * E[1]

        # average over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_avg = t_sum / self.kernel_vol  # E[t]
        p_avg = p_sum / self.kernel_vol  # E[p]

        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_avg * t_sum  # E[tp] * E[1] - E[p] * E[t] * E[1]
        t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
        p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]

        # ensure variance >= 0
        t_var = tf.maximum(t_var, 0)
        p_var = tf.maximum(p_var, 0)

        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
        ncc = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr)

        return ncc

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        TODO: support channel axis dimension > 1.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch,)
        """
        # sanity checks
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
        if y_true.shape[4] != 1:
            raise ValueError(
                "Last dimension of y_true is not one. " f"y_true.shape = {y_true.shape}"
            )
        if len(y_pred.shape) == 4:
            y_pred = tf.expand_dims(y_pred, axis=4)
        if y_pred.shape[4] != 1:
            raise ValueError(
                "Last dimension of y_pred is not one. " f"y_pred.shape = {y_pred.shape}"
            )

        ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            kernel_size=self.kernel_size,
            kernel_type=self.kernel_type,
            smooth_nr=self.smooth_nr,
            smooth_dr=self.smooth_dr,
        )
        return config
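

# Quick self-consistency sketch (hypothetical helper, not part of the original
# module): with t == p, cross equals V[t] * E[1] exactly, so the NCC map is ~1
# at every voxel and the per-sample value returned by call() is ~1.
def _demo_lncc_self_similarity() -> None:
    loss_fn = LocalNormalizedCrossCorrelation(kernel_size=9, kernel_type="rectangular")
    x = tf.random.uniform((1, 12, 12, 12, 1))
    print(loss_fn.call(y_true=x, y_pred=x))  # ~[1.0]
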

@REGISTRY.register_loss(name="lncc")
class LocalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, LocalNormalizedCrossCorrelation
):
    """Revert the sign of LocalNormalizedCrossCorrelation."""
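

# Sketch of the kernel-volume identity (hypothetical helper, not part of the
# original module), assuming deepreg.loss.util.separable_filter uses "SAME"
# padding: filtering an all-ones volume with the un-normalized kernel yields
# kernel_vol (= E[1]) at interior voxels, which is why the windowed sums in
# calc_ncc are divided by kernel_vol to obtain windowed averages.
def _demo_kernel_volume() -> None:
    loss_fn = LocalNormalizedCrossCorrelation(kernel_size=3, kernel_type="rectangular")
    ones = tf.ones((1, 8, 8, 8, 1))
    filtered = separable_filter(ones, kernel=loss_fn.kernel)
    # expected: both 27.0 for the all-ones 3x3x3 rectangular kernel
    print(float(loss_fn.kernel_vol), float(filtered[0, 4, 4, 4, 0]))
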

class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving images.
    y_true and y_pred have to be at least 4d tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
    """

    def __init__(
        self,
        name: str = "GlobalNormalizedCrossCorrelation",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the loss
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """

        axis = [a for a in range(1, len(y_true.shape))]
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        numerator = tf.abs(
            tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
        )

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)


@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, GlobalNormalizedCrossCorrelation
):
    """Revert the sign of GlobalNormalizedCrossCorrelation."""
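

# Closing sketch (hypothetical helper, not part of the original module): the
# global variant reduces over all non-batch axes at once, so a pair related by
# an affine intensity change scores ~1, and the registered "gncc" loss, which
# negates the metric, returns ~-1.
def _demo_gncc() -> None:
    x = tf.random.uniform((2, 8, 8, 8))
    y = 2.0 * x + 1.0  # affine intensity change, perfectly correlated
    print(GlobalNormalizedCrossCorrelation().call(y_true=x, y_pred=y))  # ~[1, 1]
    print(float(GlobalNormalizedCrossCorrelationLoss()(y_true=x, y_pred=y)))  # ~-1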