Passed
Pull Request — main (#733) by Yunguan, created 01:25

deepreg.loss.image (rating: A)

Complexity
    Total Complexity: 17

Size/Duplication
    Total Lines: 371
    Duplicated Lines: 0%

Importance
    Changes: 0

Metric   Value
wmc      17
eloc     157
dl       0
loc      371
rs       10
c        0
b        0
f        0

11 Methods

Rating   Name                                            Duplication   Size   Complexity
A        GlobalMutualInformation.get_config()            0             6      1
A        SumSquaredDifference.call()                     0             11     1
A        SumSquaredDifference.__init__()                 0             13     1
A        GlobalMutualInformation.__init__()              0             19     1
A        GlobalMutualInformation.call()                  0             53     2
A        LocalNormalizedCrossCorrelation.calc_ncc()      0             41     1
A        GlobalNormalizedCrossCorrelation.call()         0             19     1
A        LocalNormalizedCrossCorrelation.__init__()      0             39     2
A        GlobalNormalizedCrossCorrelation.__init__()     0             12     1
A        LocalNormalizedCrossCorrelation.call()          0             28     5
A        LocalNormalizedCrossCorrelation.get_config()    0             10     1
"""Provide different loss or metrics classes for images."""
import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
    rectangular_kernel1d,
    separable_filter,
    triangular_kernel1d,
)
from deepreg.registry import REGISTRY

@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Sum of squared distance between y_true and y_pred.

    y_true and y_pred have to be at least 1-d tensors, including the batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "SumSquaredDifference",
    ):
        """
        Init.

        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        loss = tf.keras.layers.Flatten()(loss)
        return tf.reduce_mean(loss, axis=1)
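A minimal sanity sketch for this metric (toy tensors, not part of the module): `call()` returns one value per sample, and invoking the loss object applies the SUM reduction over the batch.

import tensorflow as tf

y_true = tf.zeros((2, 4, 4, 4))
y_pred = tf.ones((2, 4, 4, 4))
ssd = SumSquaredDifference()
per_sample = ssd.call(y_true=y_true, y_pred=y_pred)  # shape (2,), mean squared difference per sample: [1., 1.]
total = ssd(y_true, y_pred)  # SUM reduction over the batch: 2.0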
class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via the Parzen windowing method.

    y_true and y_pred have to be at least 4-d tensors, including the batch axis.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equations 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "GlobalMutualInformation",
    ):
        """
        Init.

        :param num_bins: number of bins for intensity, the default value is empirical.
        :param sigma_ratio: hyper-parameter controlling the sigma of the Gaussian
            weighting function.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust shapes so that both inputs are 5-d
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensity is split into bins between 0 and 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma in the Gaussian function (weighting function W)
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar
        batch, w, h, z, c = y_true.shape
        y_true = tf.reshape(y_true, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        nb_voxels = y_true.shape[1] * 1.0  # w * h * z * c, number of voxels

        # each voxel contributes continuously to a range of histogram bins
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI: sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
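A quick behavioural check (toy volumes; intensities must lie in [0, 1] per the clipping above): mutual information of an image with itself is large, while against independent noise it is close to zero.

gmi = GlobalMutualInformation(num_bins=23, sigma_ratio=0.5)
image = tf.random.uniform((1, 8, 8, 8))  # already in [0, 1]
noise = tf.random.uniform((1, 8, 8, 8))
self_mi = gmi.call(y_true=image, y_pred=image)   # large: the entropy of the image
cross_mi = gmi.call(y_true=image, y_pred=noise)  # near zero: independent inputs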
@REGISTRY.register_loss(name="gmi")
143
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
144
    """Revert the sign of GlobalMutualInformation."""
145
146
147
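NegativeLossMixin turns the similarity metric into a minimisable loss by flipping its sign; a sketch under that assumption:

image = tf.random.uniform((1, 8, 8, 8))
mi = GlobalMutualInformation().call(y_true=image, y_pred=image)
neg = GlobalMutualInformationLoss().call(y_true=image, y_pred=image)
# neg == -mi, so gradient descent on the loss maximises mutual information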
class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    Denote y_true as t and y_pred as p. Consider a window having n elements.
    Each position in the window corresponds to a weight w_i for i=1:n.

    Define the discrete expectation in the window E[t] as

        E[t] = sum_i(w_i * t_i) / sum_i(w_i)

    Similarly, the discrete variance in the window V[t] is

        V[t] = E[t**2] - E[t] ** 2

    The local squared zero-normalized cross-correlation is therefore

        E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

    where the expectation in the numerator is

        E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

    Different kernels correspond to different weights.

    For now, y_true and y_pred have to be at least 4-d tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """

    kernel_fn_dict = dict(
        gaussian=gaussian_kernel1d,
        rectangular=rectangular_kernel1d,
        triangular=triangular_kernel1d,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LocalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param kernel_size: int. Kernel size, or kernel sigma for kernel_type="gaussian".
        :param kernel_type: str, one of "rectangular", "triangular" or "gaussian"
        :param smooth_nr: small constant added to numerator in case of zero covariance.
        :param smooth_dr: small constant added to denominator in case of zero variance.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        if kernel_type not in self.kernel_fn_dict:
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {list(self.kernel_fn_dict.keys())}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr

        # (kernel_size, )
        self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
        # E[1] = sum_i(w_i), ()
        self.kernel_vol = tf.reduce_sum(
            self.kernel[:, None, None]
            * self.kernel[None, :, None]
            * self.kernel[None, None, :]
        )

    def calc_ncc(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return NCC for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch, dim1, dim2, dim3, 1)
        """
        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, 1)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = separable_filter(y_true, kernel=self.kernel)  # E[t] * E[1]
        p_sum = separable_filter(y_pred, kernel=self.kernel)  # E[p] * E[1]
        t2_sum = separable_filter(t2, kernel=self.kernel)  # E[tt] * E[1]
        p2_sum = separable_filter(p2, kernel=self.kernel)  # E[pp] * E[1]
        tp_sum = separable_filter(tp, kernel=self.kernel)  # E[tp] * E[1]

        # average over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_avg = t_sum / self.kernel_vol  # E[t]
        p_avg = p_sum / self.kernel_vol  # E[p]

        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_avg * t_sum  # E[tp] * E[1] - E[p] * E[t] * E[1]
        t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
        p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]

        # ensure variance >= 0
        t_var = tf.maximum(t_var, 0)
        p_var = tf.maximum(p_var, 0)

        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
        ncc = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr)

        return ncc

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        TODO: support channel axis dimension > 1.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, 1)
        :return: shape = (batch,)
        """
        # sanity checks
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
        if y_true.shape[4] != 1:
            raise ValueError(
                f"Last dimension of y_true is not one. y_true.shape = {y_true.shape}"
            )
        if len(y_pred.shape) == 4:
            y_pred = tf.expand_dims(y_pred, axis=4)
        if y_pred.shape[4] != 1:
            raise ValueError(
                f"Last dimension of y_pred is not one. y_pred.shape = {y_pred.shape}"
            )

        ncc = self.calc_ncc(y_true=y_true, y_pred=y_pred)
        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            kernel_size=self.kernel_size,
            kernel_type=self.kernel_type,
            smooth_nr=self.smooth_nr,
            smooth_dr=self.smooth_dr,
        )
        return config
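A usage sketch with toy volumes (illustrative, not from the test suite): kernel_type picks the window weighting from kernel_fn_dict, and identical inputs give a squared ZNCC close to 1 in every window.

lncc = LocalNormalizedCrossCorrelation(kernel_size=9, kernel_type="gaussian")
fixed = tf.random.uniform((1, 16, 16, 16, 1))
warped = tf.random.uniform((1, 16, 16, 16, 1))
similarity = lncc.call(y_true=fixed, y_pred=warped)  # shape (1,)
identical = lncc.call(y_true=fixed, y_pred=fixed)    # close to 1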
@REGISTRY.register_loss(name="lncc")
312
class LocalNormalizedCrossCorrelationLoss(
313
    NegativeLossMixin, LocalNormalizedCrossCorrelation
314
):
315
    """Revert the sign of LocalNormalizedCrossCorrelation."""
316
317
318
class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving images.
    y_true and y_pred have to be at least 4-d tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "GlobalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param reduction: using AUTO reduction,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        axis = list(range(1, len(y_true.shape)))
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        numerator = tf.abs(
            tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
        )

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)
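Squared ZNCC is invariant to affine intensity rescaling, which is why it suits image pairs whose intensities are linearly related; a quick sketch of that property:

fixed = tf.random.uniform((2, 8, 8, 8))
rescaled = 3.0 * fixed + 0.5  # affine intensity change
gncc = GlobalNormalizedCrossCorrelation()
per_sample = gncc.call(y_true=fixed, y_pred=rescaled)  # ~1.0 per sample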
@REGISTRY.register_loss(name="gncc")
367
class GlobalNormalizedCrossCorrelationLoss(
368
    NegativeLossMixin, GlobalNormalizedCrossCorrelation
369
):
370
    """Revert the sign of GlobalNormalizedCrossCorrelation."""
371
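The registered names ("ssd", "gmi", "lncc", "gncc") are how DeepReg configuration files select an image loss. A sketch of building one by name, assuming the registry exposes a build_loss helper consistent with its other build_* methods:

from deepreg.registry import REGISTRY

loss = REGISTRY.build_loss(config={"name": "lncc", "kernel_size": 7})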