Completed
Push — main ( bbe77b...232bc8 )
by Yunguan
22s queued 13s
created

deepreg.loss.label.JaccardIndex.get_config()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Provide different loss or metrics classes for labels."""
2
3
from typing import List, Optional
4
5
import tensorflow as tf
6
7
from deepreg.constant import EPS
8
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
9
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
10
from deepreg.loss.util import separable_filter
11
from deepreg.registry import REGISTRY
12
13
14
class MultiScaleLoss(tf.keras.losses.Loss):
    """
    Base class for multi-scale loss.

    It applies the loss at different scales (gaussian or cauchy smoothing).
    It is assumed that loss values are between 0 and 1.

    Subclasses implement `_call`, which computes the loss for a single scale;
    `call` averages it over all configured scales.
    """

    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "MultiScaleLoss",
    ):
        """
        Init.

        :param scales: list of scalars or None, if None, do not apply any scaling.
            A scale equal to 0 means the loss is computed on unsmoothed inputs.
        :param kernel: gaussian or cauchy, i.e. a key of `kernel_fn_dict`.
        :param reduction: do not perform reduction over batch axis.
            this is for supporting multi-device training,
            model.fit() will average over global batch size automatically.
            Loss returns a tensor of shape (batch, ).
        :param name: str, name of the loss.
        """
        super().__init__(reduction=reduction, name=name)
        # Validate against the same mapping `call` looks kernels up in, so the
        # accepted names can never drift from the kernels actually available.
        assert kernel in self.kernel_fn_dict, (
            f"kernel must be one of {sorted(self.kernel_fn_dict)}, got {kernel}."
        )
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use _call to calculate loss at different scales.

        For every scale s, both inputs are smoothed with the configured 1D
        kernel (applied separably) before calling `_call`; s == 0 skips the
        smoothing. The per-scale losses are then averaged.

        :param y_true: ground-truth tensor, shape = (batch, dim1, dim2, dim3).
        :param y_pred: predicted tensor, shape = (batch, dim1, dim2, dim3).
        :return: multi-scale loss, shape = (batch, ).
        """
        if self.scales is None:
            return self._call(y_true=y_true, y_pred=y_pred)
        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        for s in self.scales:
            if s == 0:
                # no smoothing
                losses.append(self._call(y_true=y_true, y_pred=y_pred))
                continue
            # build the 1D kernel once per scale and reuse it for both inputs;
            # a channel axis is added for separable_filter and removed after
            kernel_1d = kernel_fn(s)
            smoothed_true = separable_filter(
                tf.expand_dims(y_true, axis=4), kernel_1d
            )[..., 0]
            smoothed_pred = separable_filter(
                tf.expand_dims(y_pred, axis=4), kernel_1d
            )[..., 0]
            losses.append(self._call(y_true=smoothed_true, y_pred=smoothed_pred))
        loss = tf.add_n(losses)
        loss = loss / len(self.scales)
        return loss

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        Must be implemented by subclasses.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: negated loss.
        """
        raise NotImplementedError

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config
99
100
101
class DiceScore(MultiScaleLoss):
    """
    Define dice score.

    The formulation is:

        0. w_fg + w_bg = 1
        1. let y_prod = y_true * y_pred and y_sum = y_true + y_pred
        2. num = 2 * (w_fg * y_true * y_pred + w_bg * (1-y_true) * (1-y_pred))
               = 2 * ((w_fg+w_bg) * y_prod - w_bg * y_sum + w_bg)
               = 2 * (y_prod - w_bg * y_sum + w_bg)
        3. denom = (w_fg * (y_true + y_pred) + w_bg * (1-y_true + 1-y_pred))
                 = (w_fg-w_bg) * y_sum + 2 * w_bg
                 = (1-2*w_bg) * y_sum + 2 * w_bg
        4. dice score = num / denom

    where num and denom are summed over all axes except the batch axis,
    w_bg = background_weight and w_fg = 1 - w_bg. With w_bg = 0 this reduces
    to the foreground-only dice score 2 * y_prod / y_sum.

    Reference:
        Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss
        function for highly unbalanced segmentations." Deep learning in medical image
        analysis and multimodal learning for clinical decision support.
        Springer, Cham, 2017. 240-248.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "DiceScore",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1 (threshold 0.5).
        :param background_weight: weight for background, where y == 0.
            Must be within [0, 1], otherwise a ValueError is raised.
        :param smooth_nr: small constant added to the numerator.
        :param smooth_dr: small constant added to the denominator, avoiding a
            division by zero e.g. when both inputs are all zero and
            background_weight is 0.
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: do not perform reduction over batch axis.
            this is for supporting multi-device training,
            model.fit() will average over global batch size automatically.
            Loss returns a tensor of shape (batch, ).
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        # The chained comparison also rejects NaN, since every comparison
        # involving NaN is False.
        if not 0 <= background_weight <= 1:
            raise ValueError(
                "The background weight for Dice Score must be "
                f"within [0, 1], got {background_weight}."
            )

        self.binary = binary
        self.background_weight = background_weight
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr
        # collapses every non-batch axis: (batch, ...) -> (batch, d)
        self.flatten = tf.keras.layers.Flatten()

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return dice score for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # for foreground class
        y_prod = tf.reduce_sum(y_true * y_pred, axis=1)
        y_sum = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalized: the per-voxel constant w_bg becomes w_bg * vol
            # once summed over all d voxels
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = 2 * (
                y_prod - self.background_weight * y_sum + self.background_weight * vol
            )
            denominator = (
                1 - 2 * self.background_weight
            ) * y_sum + 2 * self.background_weight * vol
        else:
            # foreground only
            numerator = 2 * y_prod
            denominator = y_sum

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            binary=self.binary,
            background_weight=self.background_weight,
            smooth_nr=self.smooth_nr,
            smooth_dr=self.smooth_dr,
        )
        return config
211
212
213
@REGISTRY.register_loss(name="dice")
class DiceLoss(NegativeLossMixin, DiceScore):
    """
    Dice loss, registered under the name "dice".

    This is DiceScore with its sign reverted through NegativeLossMixin;
    all arguments are those of DiceScore.
    """
216
217
218
@REGISTRY.register_loss(name="cross-entropy")
class CrossEntropy(MultiScaleLoss):
    """
    Define weighted cross-entropy.

    The formulation is:

        loss = - w_fg * y_true * log(y_pred) - w_bg * (1 - y_true) * log(1 - y_pred)

    where w_bg = background_weight and w_fg = 1 - w_bg. Each term is averaged
    over all axes except the batch axis, and `smooth` is added inside the
    logarithms. With w_bg = 0 only the foreground term is computed.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "CrossEntropy",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1 (threshold 0.5).
        :param background_weight: weight for background, where y == 0.
            Must be within [0, 1], otherwise a ValueError is raised.
        :param smooth: small constant added inside the logarithms to avoid log(0).
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: do not perform reduction over batch axis.
            this is for supporting multi-device training,
            model.fit() will average over global batch size automatically.
            Loss returns a tensor of shape (batch, ).
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        # The chained comparison also rejects NaN, since every comparison
        # involving NaN is False.
        if not 0 <= background_weight <= 1:
            raise ValueError(
                "The background weight for Cross Entropy must be "
                f"within [0, 1], got {background_weight}."
            )
        self.binary = binary
        self.background_weight = background_weight
        self.smooth = smooth
        # collapses every non-batch axis: (batch, ...) -> (batch, d)
        self.flatten = tf.keras.layers.Flatten()

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        loss_fg = -tf.reduce_mean(y_true * tf.math.log(y_pred + self.smooth), axis=1)
        if self.background_weight > 0:
            loss_bg = -tf.reduce_mean(
                (1 - y_true) * tf.math.log(1 - y_pred + self.smooth), axis=1
            )
            # w_fg = 1 - w_bg, so the two weights sum to one
            return (
                1 - self.background_weight
            ) * loss_fg + self.background_weight * loss_bg
        else:
            return loss_fg

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            binary=self.binary,
            background_weight=self.background_weight,
            smooth=self.smooth,
        )
        return config
297
298
299
class JaccardIndex(DiceScore):
    """
    Define Jaccard index.

    The foreground-only formulation (background_weight = 0) is:

        1. num = y_true * y_pred
        2. denom = y_true + y_pred - y_true * y_pred
        3. Jaccard index = num / denom

    With a background weight, the generalized formulation is:

        0. w_fg + w_bg = 1
        1. let y_prod = y_true * y_pred and y_sum = y_true + y_pred
        2. num = (w_fg * y_true * y_pred + w_bg * (1-y_true) * (1-y_pred))
               = ((w_fg+w_bg) * y_prod - w_bg * y_sum + w_bg)
               = (y_prod - w_bg * y_sum + w_bg)
        3. denom = (w_fg * (y_true + y_pred - y_true * y_pred)
                  + w_bg * (1-y_true + 1-y_pred - (1-y_true) * (1-y_pred)))
                 = w_fg * (y_sum - y_prod) + w_bg * (1-y_prod)
                 = (1-w_bg) * y_sum - y_prod + w_bg
        4. Jaccard index = num / denom

    where num and denom are summed over all axes except the batch axis,
    w_bg = background_weight and w_fg = 1 - w_bg.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "JaccardIndex",
    ):
        """
        Init.

        All arguments are forwarded to DiceScore; only the default name differs.

        :param binary: if True, project y_true, y_pred to 0 or 1 (threshold 0.5).
        :param background_weight: weight for background, where y == 0.
        :param smooth_nr: small constant added to the numerator.
        :param smooth_dr: small constant added to the denominator, avoiding a
            division by zero e.g. when both inputs are all zero and
            background_weight is 0.
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: do not perform reduction over batch axis.
            this is for supporting multi-device training,
            model.fit() will average over global batch size automatically.
            Loss returns a tensor of shape (batch, ).
        :param name: str, name of the loss.
        """
        super().__init__(
            binary=binary,
            background_weight=background_weight,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            scales=scales,
            kernel=kernel,
            reduction=reduction,
            name=name,
        )

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return Jaccard index for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # for foreground class
        y_prod = tf.reduce_sum(y_true * y_pred, axis=1)
        y_sum = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalized: the per-voxel constant w_bg becomes w_bg * vol
            # once summed over all d voxels
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = (
                y_prod - self.background_weight * y_sum + self.background_weight * vol
            )
            denominator = (
                (1 - self.background_weight) * y_sum
                - y_prod
                + self.background_weight * vol
            )
        else:
            # foreground only: intersection / union
            numerator = y_prod
            denominator = y_sum - y_prod

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)
394
395
396
@REGISTRY.register_loss(name="jaccard")
class JaccardLoss(NegativeLossMixin, JaccardIndex):
    """
    Jaccard loss, registered under the name "jaccard".

    This is JaccardIndex with its sign reverted through NegativeLossMixin;
    all arguments are those of JaccardIndex.
    """
399
400
401
def compute_centroid(mask: tf.Tensor, grid: tf.Tensor) -> tf.Tensor:
    """
    Calculate the centroid of the mask.

    Voxels whose mask value is >= 0.5 count as foreground. The centroid is the
    sum of their grid coordinates divided by their number, with EPS added to
    both sums so the division stays defined for an empty mask.

    :param mask: shape = (batch, dim1, dim2, dim3)
    :param grid: shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch, 3), batch of vectors denoting
             location of centroids.
    """
    assert len(mask.shape) == 4
    assert len(grid.shape) == 5
    spatial_axes = [1, 2, 3]
    # binarise, then add a trailing axis so the mask broadcasts against grid
    foreground = tf.cast(mask >= 0.5, dtype=tf.float32)
    foreground = tf.expand_dims(foreground, axis=-1)  # (batch, dim1, dim2, dim3, 1)
    coord_sum = tf.reduce_sum(foreground * grid, axis=spatial_axes)  # (batch, 3)
    voxel_count = tf.reduce_sum(foreground, axis=spatial_axes)  # (batch, 1)
    return (coord_sum + EPS) / (voxel_count + EPS)  # (batch, 3)
418
419
420
def compute_centroid_distance(
    y_true: tf.Tensor, y_pred: tf.Tensor, grid: tf.Tensor
) -> tf.Tensor:
    """
    Calculate the L2-distance between two tensors' centroids.

    Centroids are obtained with compute_centroid on the shared grid.

    :param y_true: tensor, shape = (batch, dim1, dim2, dim3)
    :param y_pred: tensor, shape = (batch, dim1, dim2, dim3)
    :param grid: tensor, shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch,)
    """
    pred_centroid = compute_centroid(mask=y_pred, grid=grid)  # (batch, 3)
    true_centroid = compute_centroid(mask=y_true, grid=grid)  # (batch, 3)
    offset = pred_centroid - true_centroid
    squared_distance = tf.reduce_sum(offset ** 2, axis=1)  # (batch,)
    return tf.sqrt(squared_distance)
433
434
435
def foreground_proportion(y: tf.Tensor) -> tf.Tensor:
    """
    Calculate the fraction of foreground voxels per 3D volume.

    A voxel counts as foreground when its value is >= 0.5.

    :param y: shape = (batch, dim1, dim2, dim3), a 3D label tensor
    :return: shape = (batch,), values in [0, 1]
    """
    is_foreground = tf.cast(y >= 0.5, dtype=tf.float32)
    volume_axes = [1, 2, 3]
    n_foreground = tf.reduce_sum(is_foreground, axis=volume_axes)
    n_voxels = tf.reduce_sum(tf.ones_like(is_foreground), axis=volume_axes)
    return n_foreground / n_voxels
445