Completed — Push to main (bd3202...72b597) by Yunguan
Duration: 27s (queued: 13s)

Module: deepreg.loss.image — Rating: A

Complexity
    Total Complexity    13

Size/Duplication
    Total Lines         336
    Duplicated Lines    0 %

Importance
    Changes             0

Metric  Value
wmc     13
eloc    139
dl      0
loc     336
rs      10
c       0
b       0
f       0

10 Methods

Rating  Name                                           Duplication  Size  Complexity
A       GlobalMutualInformation.get_config()           0            6     1
A       SumSquaredDifference.call()                    0            11    1
A       LocalNormalizedCrossCorrelation.__init__()     0            33    2
A       SumSquaredDifference.__init__()                0            13    1
A       GlobalMutualInformation.__init__()             0            19    1
A       LocalNormalizedCrossCorrelation.call()         0            44    2
A       GlobalMutualInformation.call()                 0            53    2
A       LocalNormalizedCrossCorrelation.get_config()   0            6     1
A       GlobalNormalizedCrossCorrelation.call()        0            19    1
A       GlobalNormalizedCrossCorrelation.__init__()    0            12    1
"""Provide different loss or metrics classes for images."""
import tensorflow as tf

from deepreg.loss.util import NegativeLossMixin
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
from deepreg.loss.util import (
    rectangular_kernel1d,
    separable_filter,
    triangular_kernel1d,
)
from deepreg.registry import REGISTRY

EPS = tf.keras.backend.epsilon()


@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Sum of squared distance between y_true and y_pred.

    y_true and y_pred have to be at least 1-D tensors, including the batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "SumSquaredDifference",
    ):
        """
        Init.

        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        loss = tf.keras.layers.Flatten()(loss)
        return tf.reduce_mean(loss, axis=1)
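
# Illustrative usage, not part of the original module: `_demo_ssd` is a
# hypothetical helper sketching that call() yields one value per sample,
# and that the SUM reduction then collapses the batch axis into a scalar.
def _demo_ssd():
    y_true = tf.zeros((2, 4, 4, 4))
    y_pred = tf.ones((2, 4, 4, 4))
    ssd = SumSquaredDifference()
    per_sample = ssd.call(y_true, y_pred)  # shape (2,), each entry 1.0 here
    scalar = ssd(y_true, y_pred)  # SUM reduction over the batch -> 2.0
    return per_sample, scalar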


class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via Parzen windowing method.

    y_true and y_pred have to be at least 4-D tensors, including the batch axis.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equation 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "GlobalMutualInformation",
    ):
        """
        Init.

        :param num_bins: number of bins for intensity, the default value is empirical.
        :param sigma_ratio: a hyperparameter for the Gaussian weighting function
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensity is split into bins between 0, 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma in the Gaussian function (weighting function W)
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar
        batch, w, h, z, c = y_true.shape
        y_true = tf.reshape(y_true, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch, w * h * z * c, 1])  # (batch, nb_voxels, 1)
        nb_voxels = y_true.shape[1] * 1.0  # w * h * z * c, number of voxels

        # each voxel contributes continuously to a range of histogram bins
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI: sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
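
# Illustrative check, not part of the original module: `_demo_gmi` is a
# hypothetical helper. The Parzen windowing above builds soft histograms, so
# the marginals pa and pb each sum to one over the bins, and an image compared
# with itself should score higher than two independent random images.
def _demo_gmi():
    x = tf.random.uniform((2, 4, 4, 4, 1))
    y = tf.random.uniform((2, 4, 4, 4, 1))
    gmi = GlobalMutualInformation(num_bins=23, sigma_ratio=0.5)
    self_info = gmi.call(x, x)  # shape (2,), MI of each image with itself
    cross_info = gmi.call(x, y)  # shape (2,), expected to be smaller
    return self_info, cross_info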


@REGISTRY.register_loss(name="gmi")
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
    """Reverse the sign of GlobalMutualInformation."""


class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    Denote y_true as t and y_pred as p. Consider a window having n elements.
    Each position in the window corresponds to a weight w_i for i=1:n.

    Define the discrete expectation in the window E[t] as

        E[t] = sum_i(w_i * t_i) / sum_i(w_i)

    Similarly, the discrete variance in the window V[t] is

        V[t] = E[t**2] - E[t] ** 2

    The local squared zero-normalized cross-correlation is therefore

        E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]

    where the expectation in the numerator is

        E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]

    Different kernels correspond to different weights.

    For now, y_true and y_pred have to be at least 4-D tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """

    kernel_fn_dict = dict(
        gaussian=gaussian_kernel1d,
        rectangular=rectangular_kernel1d,
        triangular=triangular_kernel1d,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LocalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param kernel_size: int. Kernel size, or kernel sigma for
            kernel_type="gaussian".
        :param kernel_type: str, rectangular, triangular or gaussian
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        :raises ValueError: if kernel_type is not one of the supported kernels.
        """
        super().__init__(reduction=reduction, name=name)
        if kernel_type not in self.kernel_fn_dict:
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {list(self.kernel_fn_dict.keys())}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size

        # (kernel_size, )
        self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
        # E[1] = sum_i(w_i), ()
        self.kernel_vol = tf.reduce_sum(
            self.kernel[:, None, None]
            * self.kernel[None, :, None]
            * self.kernel[None, None, :]
        )

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, ch)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = separable_filter(y_true, kernel=self.kernel)  # E[t] * E[1]
        p_sum = separable_filter(y_pred, kernel=self.kernel)  # E[p] * E[1]
        t2_sum = separable_filter(t2, kernel=self.kernel)  # E[tt] * E[1]
        p2_sum = separable_filter(p2, kernel=self.kernel)  # E[pp] * E[1]
        tp_sum = separable_filter(tp, kernel=self.kernel)  # E[tp] * E[1]

        # average over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_avg = t_sum / self.kernel_vol  # E[t]
        p_avg = p_sum / self.kernel_vol  # E[p]

        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_avg * t_sum  # E[tp] * E[1] - E[p] * E[t] * E[1]
        t_var = t2_sum - t_avg * t_sum  # V[t] * E[1]
        p_var = p2_sum - p_avg * p_sum  # V[p] * E[1]

        # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
        ncc = (cross * cross + EPS) / (t_var * p_var + EPS)

        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["kernel_size"] = self.kernel_size
        config["kernel_type"] = self.kernel_type
        return config
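
# Illustrative check, not part of the original module: `_demo_lncc` is a
# hypothetical helper. With y_true == y_pred, cross equals both t_var and
# p_var, so the squared ZNCC is 1 at every voxel and the per-sample mean
# returned by call() is approximately 1.0, regardless of kernel_type.
def _demo_lncc():
    x = tf.random.uniform((1, 12, 12, 12, 1))
    lncc = LocalNormalizedCrossCorrelation(kernel_size=9, kernel_type="rectangular")
    return lncc.call(x, x)  # shape (1,), approximately 1.0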


@REGISTRY.register_loss(name="lncc")
class LocalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, LocalNormalizedCrossCorrelation
):
    """Reverse the sign of LocalNormalizedCrossCorrelation."""


class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving images.
    y_true and y_pred have to be at least 4-D tensors, including the batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation

    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "GlobalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param reduction: using AUTO reduction,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """

        axis = list(range(1, len(y_true.shape)))
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        numerator = tf.abs(
            tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
        )

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)


@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
    NegativeLossMixin, GlobalNormalizedCrossCorrelation
):
    """Reverse the sign of GlobalNormalizedCrossCorrelation."""
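

# Illustrative sketch, not part of the original module: `_demo_gncc_sign` is a
# hypothetical helper. Per the docstrings above, NegativeLossMixin reverses the
# sign of the wrapped similarity, so minimizing the registered "gncc" loss
# maximizes the zero-normalized cross-correlation.
def _demo_gncc_sign():
    y_true = tf.random.uniform((2, 4, 4, 4))
    y_pred = tf.random.uniform((2, 4, 4, 4))
    similarity = GlobalNormalizedCrossCorrelation().call(y_true, y_pred)
    loss = GlobalNormalizedCrossCorrelationLoss().call(y_true, y_pred)
    return similarity, loss  # expected: loss == -similarity, element-wise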