Passed
Pull Request — main (#736)
by Yunguan
01:30
created

deepreg.loss.label.DiceScore.get_config()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Provide different loss or metrics classes for labels."""
2
3
from typing import List, Optional, Union
4
5
import tensorflow as tf
6
7
from deepreg.constant import EPS
8
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
9
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
10
from deepreg.loss.util import separable_filter
11
from deepreg.registry import REGISTRY
12
13
14
class MultiScaleLoss(tf.keras.losses.Loss):
    """
    Base class for multi-scale loss.

    It applies the loss at different scales (gaussian or cauchy smoothing).
    It is assumed that loss values are between 0 and 1.

    Subclasses implement ``_call``, which computes the loss at a single scale;
    ``call`` averages the ``_call`` results over all configured scales.
    """

    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[Union[List, float, int]] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "MultiScaleLoss",
    ):
        """
        Init.

        :param scales: list of scalars, a single scalar, or None.
            If None, do not apply any scaling.
            A single scalar ``s`` is treated as ``[s]``.
            A scale of 0 computes the loss on the unsmoothed inputs.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(reduction=reduction, name=name)
        assert kernel in ["gaussian", "cauchy"], (
            f"kernel must be gaussian or cauchy, got {kernel}."
        )
        if isinstance(scales, (int, float)):
            # accept a single scale without requiring it to be wrapped in a list
            scales = [scales]
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use _call to calculate loss at different scales.

        :param y_true: ground-truth tensor, shape = (batch, dim1, dim2, dim3).
        :param y_pred: predicted tensor, shape = (batch, dim1, dim2, dim3).
        :return: multi-scale loss, shape = (batch, ).
        """
        if self.scales is None:
            return self._call(y_true=y_true, y_pred=y_pred)
        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        for s in self.scales:
            if s == 0:
                # no smoothing
                losses.append(self._call(y_true=y_true, y_pred=y_pred))
                continue
            # build the 1D kernel once per scale and reuse it for both inputs;
            # a channel axis is added for separable_filter and removed afterwards
            kernel_1d = kernel_fn(s)
            smoothed_true = separable_filter(tf.expand_dims(y_true, axis=4), kernel_1d)
            smoothed_pred = separable_filter(tf.expand_dims(y_pred, axis=4), kernel_1d)
            losses.append(
                self._call(y_true=smoothed_true[..., 0], y_pred=smoothed_pred[..., 0])
            )
        loss = tf.add_n(losses)
        loss = loss / len(self.scales)
        return loss

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch at a single scale.

        Must be implemented by subclasses.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: loss, shape = (batch, ).
        """
        raise NotImplementedError

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config
99
100
101
class DiceScore(MultiScaleLoss):
    """
    Dice score, optionally giving weight to the background class.

    Let w_bg be the background weight and w_fg = 1 - w_bg.
    Write y_prod = y_true * y_pred and y_sum = y_true + y_pred. Then:

        num   = 2 * (w_fg * y_true * y_pred + w_bg * (1 - y_true) * (1 - y_pred))
              = 2 * (y_prod - w_bg * y_sum + w_bg)
        denom = w_fg * (y_true + y_pred) + w_bg * ((1 - y_true) + (1 - y_pred))
              = (1 - 2 * w_bg) * y_sum + 2 * w_bg
        dice score = num / denom

    num and denom are summed over every axis except the batch axis before
    the division. After summation, the constant terms are multiplied by the
    number of elements per sample. With w_bg = 0 this reduces to
    2 * sum(y_prod) / sum(y_sum).

    Reference:
        Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss
        function for highly unbalanced segmentations." Deep learning in medical image
        analysis and multimodal learning for clinical decision support.
        Springer, Cham, 2017. 240-248.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "DiceScore",
    ):
        """
        Init.

        :param binary: if True, threshold y_true and y_pred at 0.5 to get 0 or 1.
        :param background_weight: weight for background, where y == 0,
            must be within [0, 1].
        :param smooth_nr: small constant added to the numerator,
            so the ratio stays defined when both inputs are empty.
        :param smooth_dr: small constant added to the denominator,
            to avoid a division by zero.
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        :raises ValueError: if background_weight is below 0 or above 1.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        too_small = background_weight < 0
        too_large = background_weight > 1
        if too_small or too_large:
            raise ValueError(
                "The background weight for Dice Score must be "
                f"within [0, 1], got {background_weight}."
            )

        self.binary = binary
        self.background_weight = background_weight
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr
        self.flatten = tf.keras.layers.Flatten()

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Compute the dice score of every sample in the batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # collapse every non-batch axis: (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # foreground statistics per sample
        intersection = tf.reduce_sum(y_true * y_pred, axis=1)
        total = tf.reduce_sum(y_true + y_pred, axis=1)
        w_bg = self.background_weight

        if w_bg > 0:
            # generalized form; num_voxels scales the constant terms
            num_voxels = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = 2 * (intersection - w_bg * total + w_bg * num_voxels)
            denominator = (1 - 2 * w_bg) * total + 2 * w_bg * num_voxels
        else:
            # plain foreground dice
            numerator = 2 * intersection
            denominator = total

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["binary"] = self.binary
        config["background_weight"] = self.background_weight
        config["smooth_nr"] = self.smooth_nr
        config["smooth_dr"] = self.smooth_dr
        return config
211
212
213
@REGISTRY.register_loss(name="dice")
class DiceLoss(NegativeLossMixin, DiceScore):
    """
    Dice loss, registered under the name "dice".

    Computes the same quantity as DiceScore, with the sign reverted by
    NegativeLossMixin.
    """
216
217
218
@REGISTRY.register_loss(name="cross-entropy")
class CrossEntropy(MultiScaleLoss):
    """
    Define weighted cross-entropy.

    The formulation is:
        loss = - w_fg * y_true log(y_pred) - w_bg * (1 - y_true) log(1 - y_pred)

    Here w_bg is the background weight and w_fg = 1 - w_bg. Each term is
    averaged over all axes except the batch axis.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "CrossEntropy",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1
        :param background_weight: weight for background, where y == 0,
            must be within [0, 1].
        :param smooth: small constant added inside the logarithms to avoid log(0).
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        :raises ValueError: if background_weight is not within [0, 1].
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        # explicit check instead of `assert` so validation still runs under
        # `python -O`; the chained comparison also rejects NaN
        if not 0 <= background_weight <= 1:
            raise ValueError(
                "The background weight for Cross Entropy must be "
                f"within [0, 1], got {background_weight}."
            )
        self.binary = binary
        self.background_weight = background_weight
        self.smooth = smooth
        self.flatten = tf.keras.layers.Flatten()

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        loss_fg = -tf.reduce_mean(y_true * tf.math.log(y_pred + self.smooth), axis=1)
        if self.background_weight > 0:
            loss_bg = -tf.reduce_mean(
                (1 - y_true) * tf.math.log(1 - y_pred + self.smooth), axis=1
            )
            return (
                1 - self.background_weight
            ) * loss_fg + self.background_weight * loss_bg
        # no background weight: w_fg = 1, only the foreground term remains
        return loss_fg

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            binary=self.binary,
            background_weight=self.background_weight,
            smooth=self.smooth,
        )
        return config
293
294
295
class JaccardIndex(DiceScore):
    """
    Define Jaccard index.

    Without background weighting (w_bg = 0) the formulation is:

        1. num = y_true * y_pred
        2. denom = y_true + y_pred - y_true * y_pred
        3. Jaccard index = num / denom

    With background weight w_bg and foreground weight w_fg = 1 - w_bg,
    the formulation generalizes to:

        0. w_fg + w_bg = 1
        1. let y_prod = y_true * y_pred and y_sum = y_true + y_pred
        2. num = (w_fg * y_true * y_pred + w_bg * (1-y_true) * (1-y_pred))
               = ((w_fg+w_bg) * y_prod - w_bg * y_sum + w_bg)
               = (y_prod - w_bg * y_sum + w_bg)
        3. denom = (w_fg * (y_true + y_pred - y_true * y_pred)
                  + w_bg * (1-y_true + 1-y_pred - (1-y_true) * (1-y_pred)))
                 = w_fg * (y_sum - y_prod) + w_bg * (1-y_prod)
                 = (1-w_bg) * y_sum - y_prod + w_bg
        4. Jaccard index = num / denom

    In both cases num and denom are summed over all axes except the batch axis.
    Configuration (get_config) is inherited from DiceScore.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "JaccardIndex",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1.
        :param background_weight: weight for background, where y == 0,
            must be within [0, 1] (validated by DiceScore).
        :param smooth_nr: small constant added to the numerator,
            so the ratio stays defined when both inputs are empty.
        :param smooth_dr: small constant added to the denominator,
            to avoid a division by zero.
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(
            binary=binary,
            background_weight=background_weight,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            scales=scales,
            kernel=kernel,
            reduction=reduction,
            name=name,
        )

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return the Jaccard index for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # for foreground class
        y_prod = tf.reduce_sum(y_true * y_pred, axis=1)
        y_sum = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalized; vol is the number of elements per sample, which
            # multiplies the constant w_bg terms after summation
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = (
                y_prod - self.background_weight * y_sum + self.background_weight * vol
            )
            denominator = (
                (1 - self.background_weight) * y_sum
                - y_prod
                + self.background_weight * vol
            )
        else:
            # foreground only: intersection / union
            numerator = y_prod
            denominator = y_sum - y_prod

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)
390
391
392
@REGISTRY.register_loss(name="jaccard")
class JaccardLoss(NegativeLossMixin, JaccardIndex):
    """
    Jaccard loss, registered under the name "jaccard".

    Computes the same quantity as JaccardIndex, with the sign reverted by
    NegativeLossMixin.
    """
395
396
397
def compute_centroid(mask: tf.Tensor, grid: tf.Tensor) -> tf.Tensor:
    """
    Calculate the centroid of the mask.

    Voxels where mask >= 0.5 are treated as foreground. The centroid is the
    mean grid coordinate over those voxels. EPS is added to both numerator
    and denominator, so an empty mask does not cause a division by zero.

    :param mask: shape = (batch, dim1, dim2, dim3)
    :param grid: shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch, 3), batch of vectors denoting
             location of centroids.
    """
    assert len(mask.shape) == 4, f"mask must be 4D, got shape {mask.shape}."
    assert len(grid.shape) == 5, f"grid must be 5D, got shape {grid.shape}."
    # binarize in the grid's dtype so the product below also works for
    # non-float32 grids; float32 grids behave exactly as before
    bool_mask = tf.expand_dims(
        tf.cast(mask >= 0.5, dtype=grid.dtype), axis=4
    )  # (batch, dim1, dim2, dim3, 1)
    masked_grid = bool_mask * grid  # (batch, dim1, dim2, dim3, 3)
    numerator = tf.reduce_sum(masked_grid, axis=[1, 2, 3])  # (batch, 3)
    denominator = tf.reduce_sum(bool_mask, axis=[1, 2, 3])  # (batch, 1)
    return (numerator + EPS) / (denominator + EPS)  # (batch, 3)
414
415
416
def compute_centroid_distance(
    y_true: tf.Tensor, y_pred: tf.Tensor, grid: tf.Tensor
) -> tf.Tensor:
    """
    Calculate the L2-distance between two tensors' centroids.

    compute_centroid reduces each tensor to one centroid per sample.
    The Euclidean norm of their difference is then taken per sample.

    :param y_true: tensor, shape = (batch, dim1, dim2, dim3)
    :param y_pred: tensor, shape = (batch, dim1, dim2, dim3)
    :param grid: tensor, shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch,)
    """
    centroid_pred = compute_centroid(mask=y_pred, grid=grid)  # (batch, 3)
    centroid_true = compute_centroid(mask=y_true, grid=grid)  # (batch, 3)
    squared_offset = (centroid_pred - centroid_true) ** 2  # (batch, 3)
    # NOTE(review): tf.sqrt has an unbounded gradient at 0; confirm this is
    # only used as a metric and not differentiated when centroids coincide
    return tf.sqrt(tf.reduce_sum(squared_offset, axis=1))
429
430
431
def foreground_proportion(y: tf.Tensor) -> tf.Tensor:
    """
    Calculate the fraction of foreground voxels per 3D volume.

    Voxels where y >= 0.5 count as foreground. The result is the number of
    foreground voxels divided by the total number of voxels, a value in [0, 1].

    :param y: shape = (batch, dim1, dim2, dim3), a 3D label tensor
    :return: shape = (batch,)
    """
    y = tf.cast(y >= 0.5, dtype=tf.float32)
    # the mean of a 0/1 mask is the foreground fraction; this avoids
    # materializing a tensor of ones only to count the voxels
    return tf.reduce_mean(y, axis=[1, 2, 3])
441