Completed
Push — main ( 035ba4...a182f9 )
by Yunguan
19s queued 12s
created

GlobalMutualInformation.__init__()   A

Complexity

Conditions 1

Size

Total Lines 19
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 19
rs 9.95
c 0
b 0
f 0
cc 1
nop 5
1
"""Provide different loss or metrics classes for images."""
2
3
import tensorflow as tf
4
5
from deepreg.loss.util import NegativeLossMixin
6
from deepreg.registry import REGISTRY
7
8
# Small positive constant (Keras backend epsilon) added to numerators and
# denominators below so that divisions and logarithms never see an exact zero.
EPS = tf.keras.backend.epsilon()
9
10
11
@REGISTRY.register_loss(name="ssd")
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Squared difference between y_true and y_pred, averaged per sample.

    Despite the class name, the element-wise squared differences are averaged
    (not summed) over all non-batch elements of each sample; the resulting
    per-sample values are then combined over the batch by `reduction`.

    y_true and y_pred have to be at least 1d tensor, including batch axis.
    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "SumSquaredDifference",
    ):
        """
        Init.

        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,), mean squared difference of each sample
        """
        loss = tf.math.squared_difference(y_true, y_pred)
        # Average over every non-batch axis directly instead of instantiating
        # a new Keras Flatten layer on each call. For a 1d (batch,) input the
        # axis list is empty and the values are returned unchanged, which is
        # the same as averaging a single element per sample.
        axis = list(range(1, len(loss.shape)))
        return tf.reduce_mean(loss, axis=axis)
44
45
46
class GlobalMutualInformation(tf.keras.losses.Loss):
    """
    Differentiable global mutual information via Parzen windowing method.

    y_true and y_pred have to be 4d or 5d tensors, including batch axis;
    4d inputs are given a trailing channel axis before processing.

    Reference: https://dspace.mit.edu/handle/1721.1/123142,
        Section 3.1, equation 3.1-3.5, Algorithm 1
    """

    def __init__(
        self,
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "GlobalMutualInformation",
    ):
        """
        Init.

        :param num_bins: number of bins for intensity, the default value is empirical.
            Must be >= 2 because the Parzen window width is derived from the
            spacing between neighbouring bin centers.
        :param sigma_ratio: a hyper param for gaussian function; the Gaussian
            sigma equals sigma_ratio times the spacing between bin centers.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        :raises ValueError: if num_bins < 2.
        """
        super().__init__(reduction=reduction, name=name)
        if num_bins < 2:
            # With fewer than two bins the bin spacing (and therefore sigma)
            # is undefined and the loss would silently evaluate to NaN.
            raise ValueError(
                f"num_bins must be >= 2 for GlobalMutualInformation, got {num_bins}"
            )
        self.num_bins = num_bins
        self.sigma_ratio = sigma_ratio

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust: give 4d inputs a channel axis
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # intensity is split into bins between 0, 1
        y_true = tf.clip_by_value(y_true, 0, 1)
        y_pred = tf.clip_by_value(y_pred, 0, 1)
        bin_centers = tf.linspace(0.0, 1.0, self.num_bins)  # (num_bins,)
        bin_centers = tf.cast(bin_centers, dtype=y_true.dtype)
        bin_centers = bin_centers[None, None, ...]  # (1, 1, num_bins)
        sigma = (
            tf.reduce_mean(bin_centers[:, :, 1:] - bin_centers[:, :, :-1])
            * self.sigma_ratio
        )  # scalar, sigma in the Gaussian function (weighting function W)
        preterm = 1 / (2 * tf.math.square(sigma))  # scalar

        # Flatten all non-batch axes. The batch size is read dynamically so
        # the loss also works when the static batch (or spatial) dimension
        # is unknown (None), e.g. inside a graph built with a free batch size.
        batch_size = tf.shape(y_true)[0]
        y_true = tf.reshape(y_true, [batch_size, -1, 1])  # (batch, nb_voxels, 1)
        y_pred = tf.reshape(y_pred, [batch_size, -1, 1])  # (batch, nb_voxels, 1)
        # number of values per sample, w * h * z * c
        nb_voxels = tf.cast(tf.shape(y_true)[1], dtype=y_true.dtype)

        # each voxel contributes continuously to a range of histogram bin
        ia = tf.math.exp(
            -preterm * tf.math.square(y_true - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ia /= tf.reduce_sum(ia, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        ia = tf.transpose(ia, (0, 2, 1))  # (batch, num_bins, nb_voxels)
        pa = tf.reduce_mean(ia, axis=-1, keepdims=True)  # (batch, num_bins, 1)

        ib = tf.math.exp(
            -preterm * tf.math.square(y_pred - bin_centers)
        )  # (batch, nb_voxels, num_bins)
        ib /= tf.reduce_sum(ib, -1, keepdims=True)  # (batch, nb_voxels, num_bins)
        pb = tf.reduce_mean(ib, axis=1, keepdims=True)  # (batch, 1, num_bins)

        papb = tf.matmul(pa, pb)  # (batch, num_bins, num_bins)
        pab = tf.matmul(ia, ib)  # (batch, num_bins, num_bins)
        pab /= nb_voxels

        # MI: sum(P_ab * log(P_ab / (P_a * P_b)))
        div = (pab + EPS) / (papb + EPS)
        return tf.reduce_sum(pab * tf.math.log(div + EPS), axis=[1, 2])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["sigma_ratio"] = self.sigma_ratio
        return config
136
137
138
@REGISTRY.register_loss(name="gmi")
class GlobalMutualInformationLoss(NegativeLossMixin, GlobalMutualInformation):
    """
    Negated GlobalMutualInformation, registered as the "gmi" loss.

    NegativeLossMixin flips the sign of the mutual information, so that
    minimising this loss maximises the mutual information between images.
    """
141
142
143
def build_rectangular_kernel(
    kernel_size: int, input_channel: int
) -> (tf.Tensor, tf.Tensor):
    """
    Return a rectangular kernel for LocalNormalizedCrossCorrelation.

    :param kernel_size: size of the kernel for convolution.
    :param input_channel: number of channels for input
    :return:
        - filters, of shape (kernel_size, kernel_size, kernel_size, ch, 1)
        - kernel_vol, scalar, number of values covered by one window,
          i.e. kernel_size ** 3 * ch (equal to the sum of all filter weights)
    """
    filters = tf.ones(shape=(kernel_size, kernel_size, kernel_size, input_channel, 1))
    # Convolving with a (k, k, k, ch, 1) filter sums each window over every
    # input channel, so a window holds k**3 * ch values. Ignoring ch would
    # make sum / kernel_vol overestimate the local mean for multi-channel
    # inputs (and give a negative variance for constant images).
    kernel_vol = kernel_size ** 3 * input_channel
    return filters, tf.constant(kernel_vol)
158
159
160
def build_triangular_kernel(
    kernel_size: int, input_channel: int
) -> (tf.Tensor, tf.Tensor):
    """
    Return a triangular kernel for LocalNormalizedCrossCorrelation.

    The kernel is the full 3d convolution of two cubic box filters of side
    fsize = (kernel_size + 1) // 2, each valued 1 / fsize. Equivalently it is
    the separable product of three 1d triangles [1, 2, ..., fsize, ..., 2, 1]
    scaled by 1 / fsize ** 2. Its side 2 * fsize - 1 is always odd
    (kernel_size when kernel_size is odd, kernel_size - 1 otherwise), so the
    window stays centred when used with "SAME" padding.

    :param kernel_size: size of the kernel for convolution, must be >= 1.
    :param input_channel: number of channels for input
    :return:
        - filters, of shape (2*fsize-1, 2*fsize-1, 2*fsize-1, ch, 1)
        - kernel_vol, scalar, sum of all filter weights, so that a
          convolution with filters divided by kernel_vol is a weighted mean
    :raises ValueError: if kernel_size < 1.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    fsize = int((kernel_size + 1) / 2)

    # 1d triangle = box(fsize) conv box(fsize) = [1, 2, ..., fsize, ..., 2, 1]
    offsets = tf.range(-(fsize - 1), fsize)  # (2*fsize-1,)
    tri = tf.cast(fsize - tf.abs(offsets), dtype=tf.float32)
    kernel3d = tri[:, None, None] * tri[None, :, None] * tri[None, None, :]
    kernel3d = kernel3d / float(fsize ** 2)

    # identical spatial weights for every input channel; conv3d sums channels
    filters = tf.tile(kernel3d[:, :, :, None, None], [1, 1, 1, input_channel, 1])
    # normaliser of a weighted mean is the sum of the weights
    kernel_vol = tf.reduce_sum(filters)

    return filters, kernel_vol
192
193
194
def build_gaussian_kernel(
    kernel_size: int, input_channel: int
) -> (tf.Tensor, tf.Tensor):
    """
    Return a Gaussian kernel for LocalNormalizedCrossCorrelation.

    The spatial weights are exp(-(x² + y² + z²) / (2 sigma²)), with offsets
    measured from the kernel center (kernel_size - 1) / 2 and
    sigma = kernel_size / 3. The same weights are used for every input channel.

    :param kernel_size: size of the kernel for convolution, must be >= 1.
    :param input_channel: number of channels for input
    :return:
        - filters, of shape (kernel_size, kernel_size, kernel_size, ch, 1)
        - kernel_vol, scalar, sum of all filter weights, so that a
          convolution with filters divided by kernel_vol is a weighted mean
    :raises ValueError: if kernel_size < 1.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    mean = (kernel_size - 1) / 2.0
    sigma = kernel_size / 3

    # 1d Gaussian over spatial offsets only. The channel index is not a
    # spatial coordinate and must not change the weights.
    coords = tf.range(kernel_size, dtype=tf.float32) - mean  # (kernel_size,)
    gauss1d = tf.exp(-tf.square(coords) / (2 * sigma ** 2))
    kernel3d = (
        gauss1d[:, None, None] * gauss1d[None, :, None] * gauss1d[None, None, :]
    )  # (kernel_size, kernel_size, kernel_size)

    # identical spatial weights for every input channel; conv3d sums channels
    filters = tf.tile(kernel3d[:, :, :, None, None], [1, 1, 1, input_channel, 1])
    # normaliser of a weighted mean is the sum of the weights
    kernel_vol = tf.reduce_sum(filters)

    return filters, kernel_vol
222
223
224
class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Local squared zero-normalized cross-correlation.

    The loss is based on a moving kernel/window over the y_true/y_pred,
    within the window the square of zncc is calculated.
    The kernel can be a rectangular / triangular / gaussian window.
    The final loss is the averaged loss over all windows.
    y_true and y_pred have to be 4d or 5d tensors, including batch axis;
    4d inputs are given a trailing channel axis before processing.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation
        - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
    """

    # maps kernel_type to the function building (filters, kernel_vol)
    kernel_fn_dict = dict(
        gaussian=build_gaussian_kernel,
        rectangular=build_rectangular_kernel,
        triangular=build_triangular_kernel,
    )

    def __init__(
        self,
        kernel_size: int = 9,
        kernel_type: str = "rectangular",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "LocalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param kernel_size: int, size of the window passed to the kernel
            builder for every kernel_type (the gaussian builder derives its
            sigma from it).
        :param kernel_type: str, one of "rectangular", "triangular" or "gaussian".
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        :raises ValueError: if kernel_type is not a key of kernel_fn_dict.
        """
        super().__init__(reduction=reduction, name=name)
        if kernel_type not in self.kernel_fn_dict:
            raise ValueError(
                f"Wrong kernel_type {kernel_type} for LNCC loss type. "
                f"Feasible values are {self.kernel_fn_dict.keys()}"
            )
        self.kernel_fn = self.kernel_fn_dict[kernel_type]
        self.kernel_type = kernel_type
        self.kernel_size = kernel_size

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :param y_pred: shape = (batch, dim1, dim2, dim3)
            or (batch, dim1, dim2, dim3, ch)
        :return: shape = (batch,)
        """
        # adjust: give 4d inputs a channel axis
        if len(y_true.shape) == 4:
            y_true = tf.expand_dims(y_true, axis=4)
            y_pred = tf.expand_dims(y_pred, axis=4)
        assert len(y_true.shape) == len(y_pred.shape) == 5

        # kernel_vol is used below as the divisor turning a windowed sum
        # into the window's mean
        filters, kernel_vol = self.kernel_fn(
            kernel_size=self.kernel_size,
            input_channel=y_true.shape[4],
        )
        filters = tf.cast(filters, dtype=y_true.dtype)
        kernel_vol = tf.cast(kernel_vol, dtype=y_true.dtype)
        strides = [1, 1, 1, 1, 1]
        padding = "SAME"

        # t = y_true, p = y_pred
        # (batch, dim1, dim2, dim3, ch)
        t2 = y_true * y_true
        p2 = y_pred * y_pred
        tp = y_true * y_pred

        # sum over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_sum = tf.nn.conv3d(y_true, filters=filters, strides=strides, padding=padding)
        p_sum = tf.nn.conv3d(y_pred, filters=filters, strides=strides, padding=padding)
        t2_sum = tf.nn.conv3d(t2, filters=filters, strides=strides, padding=padding)
        p2_sum = tf.nn.conv3d(p2, filters=filters, strides=strides, padding=padding)
        tp_sum = tf.nn.conv3d(tp, filters=filters, strides=strides, padding=padding)

        # average over kernel
        # (batch, dim1, dim2, dim3, 1)
        t_avg = t_sum / kernel_vol
        p_avg = p_sum / kernel_vol

        # normalized cross correlation between t and p
        # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p]
        # denoted by num / denom
        # assume we sum over N values
        # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]]
        #     = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N
        #     = sum[t*p] - sum[t] * sum[p] / N
        #     = sum[t*p] - sum[t] * mean[p] = cross
        # the following is actually squared ncc
        # shape = (batch, dim1, dim2, dim3, 1)
        cross = tp_sum - p_avg * t_sum
        t_var = t2_sum - t_avg * t_sum  # std[t] ** 2
        p_var = p2_sum - p_avg * p_sum  # std[p] ** 2
        ncc = (cross * cross + EPS) / (t_var * p_var + EPS)
        return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["kernel_size"] = self.kernel_size
        config["kernel_type"] = self.kernel_type
        return config
339
340
341
@REGISTRY.register_loss(name="lncc")
class LocalNormalizedCrossCorrelationLoss(
    NegativeLossMixin,
    LocalNormalizedCrossCorrelation,
):
    """
    Negated LocalNormalizedCrossCorrelation, registered as the "lncc" loss.

    NegativeLossMixin flips the sign of the similarity, so that minimising
    this loss maximises the local normalized cross-correlation.
    """
346
347
348
class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
    """
    Global squared zero-normalized cross-correlation.

    Compute the squared cross-correlation between the reference and moving
    images over all non-batch axes of each sample.
    y_true and y_pred have to be at least 2d tensors, including batch axis.

    Reference:

        - Zero-normalized cross-correlation (ZNCC):
            https://en.wikipedia.org/wiki/Cross-correlation

    """

    def __init__(
        self,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "GlobalNormalizedCrossCorrelation",
    ):
        """
        Init.

        :param reduction: using AUTO reduction,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: name of the loss
        """
        super().__init__(reduction=reduction, name=name)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        # every axis except the batch axis
        axis = list(range(1, len(y_true.shape)))
        mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
        mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
        var_pred = tf.math.reduce_variance(y_pred, axis=axis)
        var_true = tf.math.reduce_variance(y_true, axis=axis)
        # covariance per sample; its sign does not matter since it is squared
        numerator = tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)

        return (numerator * numerator + EPS) / (var_pred * var_true + EPS)
394
395
396
@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
    NegativeLossMixin,
    GlobalNormalizedCrossCorrelation,
):
    """
    Negated GlobalNormalizedCrossCorrelation, registered as the "gncc" loss.

    NegativeLossMixin flips the sign of the similarity, so that minimising
    this loss maximises the global normalized cross-correlation.
    """
401