Completed
Push — main (bbe77b...232bc8) by Yunguan
22s (queued 13s)

deepreg.loss.label.CrossEntropy.__init__()  (grade A)

Complexity
  Conditions: 3

Size
  Total Lines: 33
  Code Lines: 17

Duplication
  Lines: 0
  Ratio: 0%

Importance
  Changes: 0

Metric Value
eloc 17
dl 0
loc 33
rs 9.55
c 0
b 0
f 0
cc 3
nop 8

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoiding a long parameter list; one of them, introducing a parameter object, is sketched below.
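
The following is a hypothetical refactor, not part of DeepReg: the options that CrossEntropy.__init__() currently takes as separate keyword arguments (nop = 8 above) move into one dataclass, so the constructor signature stays stable when options are added or changed.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LabelLossConfig:
    """Hypothetical parameter object bundling the label-loss options."""

    binary: bool = False
    background_weight: float = 0.0
    smooth: float = 1.0e-5  # stand-in for deepreg.constant.EPS
    scales: Optional[List] = None
    kernel: str = "gaussian"


# CrossEntropy.__init__ could then keep only the Keras-specific arguments:
#
#     def __init__(self, config: Optional[LabelLossConfig] = None,
#                  reduction=tf.keras.losses.Reduction.NONE,
#                  name: str = "CrossEntropy"):
#         config = config or LabelLossConfig()

Source under review (deepreg/loss/label.py):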

"""Provide different loss or metrics classes for labels."""

from typing import List, Optional

import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
from deepreg.loss.util import separable_filter
from deepreg.registry import REGISTRY


class MultiScaleLoss(tf.keras.losses.Loss):
    """
    Base class for multi-scale loss.

    It applies the loss at different scales (gaussian or cauchy smoothing).
    It is assumed that loss values are between 0 and 1.
    """

    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "MultiScaleLoss",
    ):
        """
        Init.

        :param scales: list of scalars or None; if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: do not perform reduction over batch axis.
            this is for supporting multi-device training,
            model.fit() will average over global batch size automatically.
            Loss returns a tensor of shape (batch, ).
        :param name: str, name of the loss.
        """
        super().__init__(reduction=reduction, name=name)
        assert kernel in ["gaussian", "cauchy"]
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use _call to calculate loss at different scales.

        :param y_true: ground-truth tensor, shape = (batch, dim1, dim2, dim3).
        :param y_pred: predicted tensor, shape = (batch, dim1, dim2, dim3).
        :return: multi-scale loss, shape = (batch, ).
        """
        if self.scales is None:
            return self._call(y_true=y_true, y_pred=y_pred)
        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        for s in self.scales:
            if s == 0:
                # no smoothing
                losses.append(
                    self._call(
                        y_true=y_true,
                        y_pred=y_pred,
                    )
                )
            else:
                losses.append(
                    self._call(
                        y_true=separable_filter(
                            tf.expand_dims(y_true, axis=4), kernel_fn(s)
                        )[..., 0],
                        y_pred=separable_filter(
                            tf.expand_dims(y_pred, axis=4), kernel_fn(s)
                        )[..., 0],
                    )
                )
        loss = tf.add_n(losses)
        loss = loss / len(self.scales)
        return loss

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: loss tensor, shape = (batch, ).
        """
        raise NotImplementedError

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config


class DiceScore(MultiScaleLoss):
    """
    Define dice score.

    The formulation is:

        0. w_fg + w_bg = 1
        1. let y_prod = y_true * y_pred and y_sum = y_true + y_pred
        2. num = 2 * (w_fg * y_true * y_pred + w_bg * (1-y_true) * (1-y_pred))
               = 2 * ((w_fg+w_bg) * y_prod - w_bg * y_sum + w_bg)
               = 2 * (y_prod - w_bg * y_sum + w_bg)
        3. denom = (w_fg * (y_true + y_pred) + w_bg * (1-y_true + 1-y_pred))
                 = (w_fg-w_bg) * y_sum + 2 * w_bg
                 = (1-2*w_bg) * y_sum + 2 * w_bg
        4. dice score = num / denom

    where num and denom are summed over all axes except the batch axis.

    Reference:
        Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss
        function for highly unbalanced segmentations." Deep learning in medical image
        analysis and multimodal learning for clinical decision support.
        Springer, Cham, 2017. 240-248.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "DiceScore",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1.
        :param background_weight: weight for background, where y == 0.
        :param smooth_nr: small constant added to numerator in case both masks are empty.
        :param smooth_dr: small constant added to denominator to avoid division by zero.
        :param scales: list of scalars or None; if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: do not perform reduction over batch axis.
            this is for supporting multi-device training,
            model.fit() will average over global batch size automatically.
            Loss returns a tensor of shape (batch, ).
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        if background_weight < 0 or background_weight > 1:
            raise ValueError(
                "The background weight for Dice Score must be "
                f"within [0, 1], got {background_weight}."
            )

        self.binary = binary
        self.background_weight = background_weight
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr
        self.flatten = tf.keras.layers.Flatten()

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # for foreground class
        y_prod = tf.reduce_sum(y_true * y_pred, axis=1)
        y_sum = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalized
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = 2 * (
                y_prod - self.background_weight * y_sum + self.background_weight * vol
            )
            denominator = (
                1 - 2 * self.background_weight
            ) * y_sum + 2 * self.background_weight * vol
        else:
            # foreground only
            numerator = 2 * y_prod
            denominator = y_sum

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            binary=self.binary,
            background_weight=self.background_weight,
            smooth_nr=self.smooth_nr,
            smooth_dr=self.smooth_dr,
        )
        return config


@REGISTRY.register_loss(name="dice")
214
class DiceLoss(NegativeLossMixin, DiceScore):
215
    """Revert the sign of DiceScore."""
216
217
218
@REGISTRY.register_loss(name="cross-entropy")
219
class CrossEntropy(MultiScaleLoss):
220
    """
221
    Define weighted cross-entropy.
222
223
    The formulation is:
224
        loss = − w_fg * y_true log(y_pred) - w_bg * (1−y_true) log(1−y_pred)
225
    """
226
227
    def __init__(
228
        self,
229
        binary: bool = False,
230
        background_weight: float = 0.0,
231
        smooth: float = EPS,
232
        scales: Optional[List] = None,
233
        kernel: str = "gaussian",
234
        reduction: str = tf.keras.losses.Reduction.NONE,
235
        name: str = "CrossEntropy",
236
    ):
237
        """
238
        Init.
239
240
        :param binary: if True, project y_true, y_pred to 0 or 1
241
        :param background_weight: weight for background, where y == 0.
242
        :param scales: list of scalars or None, if None, do not apply any scaling.
243
        :param kernel: gaussian or cauchy.
244
        :param reduction: do not perform reduction over batch axis.
245
            this is for supporting multi-device training,
246
            model.fit() will average over global batch size automatically.
247
            Loss returns a tensor of shape (batch, ).
248
        :param name: str, name of the loss.
249
        """
250
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
251
        if background_weight < 0 or background_weight > 1:
252
            raise ValueError(
253
                "The background weight for Cross Entropy must be "
254
                f"within [0, 1], got {background_weight}."
255
            )
256
        self.binary = binary
257
        self.background_weight = background_weight
258
        self.smooth = smooth
259
        self.flatten = tf.keras.layers.Flatten()
260
261
    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
262
        """
263
        Return loss for a batch.
264
265
        :param y_true: shape = (batch, ...)
266
        :param y_pred: shape = (batch, ...)
267
        :return: shape = (batch,)
268
        """
269
        if self.binary:
270
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
271
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)
272
273
        # (batch, ...) -> (batch, d)
274
        y_true = self.flatten(y_true)
275
        y_pred = self.flatten(y_pred)
276
277
        loss_fg = -tf.reduce_mean(y_true * tf.math.log(y_pred + self.smooth), axis=1)
278
        if self.background_weight > 0:
279
            loss_bg = -tf.reduce_mean(
280
                (1 - y_true) * tf.math.log(1 - y_pred + self.smooth), axis=1
281
            )
282
            return (
283
                1 - self.background_weight
284
            ) * loss_fg + self.background_weight * loss_bg
285
        else:
286
            return loss_fg
287
288
    def get_config(self) -> dict:
289
        """Return the config dictionary for recreating this class."""
290
        config = super().get_config()
291
        config.update(
292
            binary=self.binary,
293
            background_weight=self.background_weight,
294
            smooth=self.smooth,
295
        )
296
        return config
297
298
299
class JaccardIndex(DiceScore):
    """
    Define Jaccard index.

    The foreground-only formulation is:

    1. num = y_true * y_pred
    2. denom = y_true + y_pred - y_true * y_pred
    3. Jaccard index = num / denom

    The generalized (weighted) formulation is:

        0. w_fg + w_bg = 1
        1. let y_prod = y_true * y_pred and y_sum = y_true + y_pred
        2. num = (w_fg * y_true * y_pred + w_bg * (1-y_true) * (1-y_pred))
               = ((w_fg+w_bg) * y_prod - w_bg * y_sum + w_bg)
               = (y_prod - w_bg * y_sum + w_bg)
        3. denom = (w_fg * (y_true + y_pred - y_true * y_pred)
                  + w_bg * (1-y_true + 1-y_pred - (1-y_true) * (1-y_pred)))
                 = w_fg * (y_sum - y_prod) + w_bg * (1-y_prod)
                 = (1-w_bg) * y_sum - y_prod + w_bg
        4. Jaccard index = num / denom
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "JaccardIndex",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1.
        :param background_weight: weight for background, where y == 0.
        :param smooth_nr: small constant added to numerator in case both masks are empty.
        :param smooth_dr: small constant added to denominator to avoid division by zero.
        :param scales: list of scalars or None; if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: do not perform reduction over batch axis.
            this is for supporting multi-device training,
            model.fit() will average over global batch size automatically.
            Loss returns a tensor of shape (batch, ).
        :param name: str, name of the loss.
        """
        super().__init__(
            binary=binary,
            background_weight=background_weight,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            scales=scales,
            kernel=kernel,
            reduction=reduction,
            name=name,
        )

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # for foreground class
        y_prod = tf.reduce_sum(y_true * y_pred, axis=1)
        y_sum = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalized
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = (
                y_prod - self.background_weight * y_sum + self.background_weight * vol
            )
            denominator = (
                (1 - self.background_weight) * y_sum
                - y_prod
                + self.background_weight * vol
            )
        else:
            # foreground only
            numerator = y_prod
            denominator = y_sum - y_prod

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)


@REGISTRY.register_loss(name="jaccard")
class JaccardLoss(NegativeLossMixin, JaccardIndex):
    """Revert the sign of JaccardIndex."""


def compute_centroid(mask: tf.Tensor, grid: tf.Tensor) -> tf.Tensor:
    """
    Calculate the centroid of the mask.

    :param mask: shape = (batch, dim1, dim2, dim3)
    :param grid: shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch, 3), batch of vectors denoting
             location of centroids.
    """
    assert len(mask.shape) == 4
    assert len(grid.shape) == 5
    bool_mask = tf.expand_dims(
        tf.cast(mask >= 0.5, dtype=tf.float32), axis=4
    )  # (batch, dim1, dim2, dim3, 1)
    masked_grid = bool_mask * grid  # (batch, dim1, dim2, dim3, 3)
    numerator = tf.reduce_sum(masked_grid, axis=[1, 2, 3])  # (batch, 3)
    denominator = tf.reduce_sum(bool_mask, axis=[1, 2, 3])  # (batch, 1)
    return (numerator + EPS) / (denominator + EPS)  # (batch, 3)


def compute_centroid_distance(
    y_true: tf.Tensor, y_pred: tf.Tensor, grid: tf.Tensor
) -> tf.Tensor:
    """
    Calculate the L2-distance between two tensors' centroids.

    :param y_true: tensor, shape = (batch, dim1, dim2, dim3)
    :param y_pred: tensor, shape = (batch, dim1, dim2, dim3)
    :param grid: tensor, shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch,)
    """
    centroid_1 = compute_centroid(mask=y_pred, grid=grid)  # (batch, 3)
    centroid_2 = compute_centroid(mask=y_true, grid=grid)  # (batch, 3)
    return tf.sqrt(tf.reduce_sum((centroid_1 - centroid_2) ** 2, axis=1))


def foreground_proportion(y: tf.Tensor) -> tf.Tensor:
    """
    Calculate the proportion of foreground voxels per 3D volume.

    :param y: shape = (batch, dim1, dim2, dim3), a 3D label tensor
    :return: shape = (batch,)
    """
    y = tf.cast(y >= 0.5, dtype=tf.float32)
    return tf.reduce_sum(y, axis=[1, 2, 3]) / tf.reduce_sum(
        tf.ones_like(y), axis=[1, 2, 3]
    )
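
For reference, a minimal usage sketch, assuming TensorFlow is installed and the classes above are importable (this is an illustration, not from the DeepReg documentation). Because reduction defaults to NONE, calling a loss returns one value per batch element:

import tensorflow as tf

# Dummy label masks / predicted probabilities, shape (batch, dim1, dim2, dim3).
y_true = tf.cast(tf.random.uniform((2, 4, 4, 4)) >= 0.5, tf.float32)
y_pred = tf.random.uniform((2, 4, 4, 4))

dice = DiceLoss(background_weight=0.2)  # generalized Dice score, negated as a loss
xent = CrossEntropy(scales=[0, 1])      # cross-entropy averaged over two scales

print(dice(y_true, y_pred).shape)  # (2,)
print(xent(y_true, y_pred).shape)  # (2,)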