Passed
Pull Request — main (#736)
created by Yunguan at 01:30

deepreg.loss.label.JaccardIndex.__init__() — grade A

Complexity

Conditions 1

Size

Total Lines 35
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
eloc    19
dl      0
loc     35
rs      9.45
c       0
b       0
f       0
cc      1
nop     9

How to fix: Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists: introduce a parameter object that groups related arguments, pass a whole object rather than several of its fields, or replace a parameter with a method call when the callee can compute the value itself. A sketch of the parameter-object approach follows.
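As an illustration only (none of the following is DeepReg code): JaccardIndex.__init__ takes nine parameters counting self, matching the nop 9 metric above, and DiceScore and JaccardIndex repeat the same four smoothing-related arguments. A hypothetical SmoothingConfig parameter object could bundle them:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SmoothingConfig:
    """Hypothetical bundle of the smoothing arguments shared by the label losses."""

    smooth_nr: float = 1.0e-6  # stand-in for deepreg.constant.EPS
    smooth_dr: float = 1.0e-6
    scales: Optional[List] = None
    kernel: str = "gaussian"


class JaccardIndexSketch:
    """Sketch only: the constructor shrinks from nine parameters to six."""

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smoothing: Optional[SmoothingConfig] = None,
        reduction: str = "sum",  # stand-in for tf.keras.losses.Reduction.SUM
        name: str = "JaccardIndex",
    ):
        # One cohesive object replaces four loose keyword arguments.
        self.smoothing = smoothing or SmoothingConfig()
        self.binary = binary
        self.background_weight = background_weight
        self.reduction = reduction
        self.name = name


loss = JaccardIndexSketch(smoothing=SmoothingConfig(kernel="cauchy"))

Adding a new smoothing option would then touch SmoothingConfig once instead of every loss constructor's signature.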

"""Provide different loss or metrics classes for labels."""

from typing import List, Optional

import tensorflow as tf

from deepreg.constant import EPS
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
from deepreg.loss.util import separable_filter
from deepreg.registry import REGISTRY


class MultiScaleLoss(tf.keras.losses.Loss):
    """
    Base class for multi-scale loss.

    It applies the loss at different scales (gaussian or cauchy smoothing).
    It is assumed that loss values are between 0 and 1.
    """

    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "MultiScaleLoss",
    ):
        """
        Init.

        :param scales: list of scalars or None; if None, do not apply any smoothing.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(reduction=reduction, name=name)
        assert kernel in ["gaussian", "cauchy"]
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use _call to calculate loss at different scales.

        :param y_true: ground-truth tensor, shape = (batch, dim1, dim2, dim3).
        :param y_pred: predicted tensor, shape = (batch, dim1, dim2, dim3).
        :return: multi-scale loss, shape = (batch,).
        """
        if self.scales is None:
            return self._call(y_true=y_true, y_pred=y_pred)
        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        for s in self.scales:
            if s == 0:
                # no smoothing
                losses.append(
                    self._call(
                        y_true=y_true,
                        y_pred=y_pred,
                    )
                )
            else:
                losses.append(
                    self._call(
                        y_true=separable_filter(
                            tf.expand_dims(y_true, axis=4), kernel_fn(s)
                        )[..., 0],
                        y_pred=separable_filter(
                            tf.expand_dims(y_pred, axis=4), kernel_fn(s)
                        )[..., 0],
                    )
                )
        loss = tf.add_n(losses)
        loss = loss / len(self.scales)
        return loss

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: loss value, shape = (batch,).
        """
        raise NotImplementedError

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config


class DiceScore(MultiScaleLoss):
    """
    Define the Dice score.

    The formulation is:

        0. w_fg + w_bg = 1
        1. let y_prod = y_true * y_pred and y_sum = y_true + y_pred
        2. num = 2 * (w_fg * y_true * y_pred + w_bg * (1-y_true) * (1-y_pred))
               = 2 * ((w_fg+w_bg) * y_prod - w_bg * y_sum + w_bg)
               = 2 * (y_prod - w_bg * y_sum + w_bg)
        3. denom = (w_fg * (y_true + y_pred) + w_bg * (1-y_true + 1-y_pred))
                 = (w_fg-w_bg) * y_sum + 2 * w_bg
                 = (1-2*w_bg) * y_sum + 2 * w_bg
        4. dice score = num / denom

    where num and denom are summed over all axes except the batch axis.

    Reference:
        Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss
        function for highly unbalanced segmentations." Deep learning in medical image
        analysis and multimodal learning for clinical decision support.
        Springer, Cham, 2017. 240-248.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "DiceScore",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1.
        :param background_weight: weight for background, where y == 0.
        :param smooth_nr: small constant added to the numerator for numerical
            stability.
        :param smooth_dr: small constant added to the denominator to avoid
            division by zero.
        :param scales: list of scalars or None; if None, do not apply any smoothing.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        if background_weight < 0 or background_weight > 1:
            raise ValueError(
                "The background weight for Dice Score must be "
                f"within [0, 1], got {background_weight}."
            )

        self.binary = binary
        self.background_weight = background_weight
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr
        self.flatten = tf.keras.layers.Flatten()

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # for foreground class
        y_prod = tf.reduce_sum(y_true * y_pred, axis=1)
        y_sum = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalized
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = 2 * (
                y_prod - self.background_weight * y_sum + self.background_weight * vol
            )
            denominator = (
                1 - 2 * self.background_weight
            ) * y_sum + 2 * self.background_weight * vol
        else:
            # foreground only
            numerator = 2 * y_prod
            denominator = y_sum

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            binary=self.binary,
            background_weight=self.background_weight,
            smooth_nr=self.smooth_nr,
            smooth_dr=self.smooth_dr,
        )
        return config


@REGISTRY.register_loss(name="dice")
class DiceLoss(NegativeLossMixin, DiceScore):
    """Reverse the sign of DiceScore."""


@REGISTRY.register_loss(name="cross-entropy")
class CrossEntropy(MultiScaleLoss):
    """
    Define weighted cross-entropy.

    The formulation is:
        loss = - w_fg * y_true * log(y_pred) - w_bg * (1-y_true) * log(1-y_pred)
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "CrossEntropy",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1.
        :param background_weight: weight for background, where y == 0.
        :param smooth: small constant added inside the log to avoid log(0).
        :param scales: list of scalars or None; if None, do not apply any smoothing.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        assert 0 <= background_weight <= 1
        self.binary = binary
        self.background_weight = background_weight
        self.smooth = smooth
        self.flatten = tf.keras.layers.Flatten()

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        loss_fg = -tf.reduce_mean(y_true * tf.math.log(y_pred + self.smooth), axis=1)
        if self.background_weight > 0:
            loss_bg = -tf.reduce_mean(
                (1 - y_true) * tf.math.log(1 - y_pred + self.smooth), axis=1
            )
            return (
                1 - self.background_weight
            ) * loss_fg + self.background_weight * loss_bg
        else:
            return loss_fg

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            binary=self.binary,
            background_weight=self.background_weight,
            smooth=self.smooth,
        )
        return config


class JaccardIndex(DiceScore):
    """
    Define the Jaccard index.

    The unweighted formulation is:

        1. num = y_true * y_pred
        2. denom = y_true + y_pred - y_true * y_pred
        3. Jaccard index = num / denom

    With background weighting it becomes:

        0. w_fg + w_bg = 1
        1. let y_prod = y_true * y_pred and y_sum = y_true + y_pred
        2. num = (w_fg * y_true * y_pred + w_bg * (1-y_true) * (1-y_pred))
               = ((w_fg+w_bg) * y_prod - w_bg * y_sum + w_bg)
               = (y_prod - w_bg * y_sum + w_bg)
        3. denom = (w_fg * (y_true + y_pred - y_true * y_pred)
                  + w_bg * (1-y_true + 1-y_pred - (1-y_true) * (1-y_pred)))
                 = w_fg * (y_sum - y_prod) + w_bg * (1-y_prod)
                 = (1-w_bg) * y_sum - y_prod + w_bg
        4. Jaccard index = num / denom
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "JaccardIndex",
    ):
        """
        Init.

        :param binary: if True, project y_true, y_pred to 0 or 1.
        :param background_weight: weight for background, where y == 0.
        :param smooth_nr: small constant added to the numerator for numerical
            stability.
        :param smooth_dr: small constant added to the denominator to avoid
            division by zero.
        :param scales: list of scalars or None; if None, do not apply any smoothing.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            this is for supporting multi-device training,
            and the loss will be divided by global batch size,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(
            binary=binary,
            background_weight=background_weight,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            scales=scales,
            kernel=kernel,
            reduction=reduction,
            name=name,
        )

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true = self.flatten(y_true)
        y_pred = self.flatten(y_pred)

        # for foreground class
        y_prod = tf.reduce_sum(y_true * y_pred, axis=1)
        y_sum = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalized
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = (
                y_prod - self.background_weight * y_sum + self.background_weight * vol
            )
            denominator = (
                (1 - self.background_weight) * y_sum
                - y_prod
                + self.background_weight * vol
            )
        else:
            # foreground only
            numerator = y_prod
            denominator = y_sum - y_prod

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)


@REGISTRY.register_loss(name="jaccard")
class JaccardLoss(NegativeLossMixin, JaccardIndex):
    """Reverse the sign of JaccardIndex."""


def compute_centroid(mask: tf.Tensor, grid: tf.Tensor) -> tf.Tensor:
    """
    Calculate the centroid of the mask.

    :param mask: shape = (batch, dim1, dim2, dim3)
    :param grid: shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch, 3), batch of vectors denoting
             location of centroids.
    """
    assert len(mask.shape) == 4
    assert len(grid.shape) == 5
    bool_mask = tf.expand_dims(
        tf.cast(mask >= 0.5, dtype=tf.float32), axis=4
    )  # (batch, dim1, dim2, dim3, 1)
    masked_grid = bool_mask * grid  # (batch, dim1, dim2, dim3, 3)
    numerator = tf.reduce_sum(masked_grid, axis=[1, 2, 3])  # (batch, 3)
    denominator = tf.reduce_sum(bool_mask, axis=[1, 2, 3])  # (batch, 1)
    return (numerator + EPS) / (denominator + EPS)  # (batch, 3)


def compute_centroid_distance(
    y_true: tf.Tensor, y_pred: tf.Tensor, grid: tf.Tensor
) -> tf.Tensor:
    """
    Calculate the L2-distance between two tensors' centroids.

    :param y_true: tensor, shape = (batch, dim1, dim2, dim3)
    :param y_pred: tensor, shape = (batch, dim1, dim2, dim3)
    :param grid: tensor, shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch,)
    """
    centroid_1 = compute_centroid(mask=y_pred, grid=grid)  # (batch, 3)
    centroid_2 = compute_centroid(mask=y_true, grid=grid)  # (batch, 3)
    return tf.sqrt(tf.reduce_sum((centroid_1 - centroid_2) ** 2, axis=1))


def foreground_proportion(y: tf.Tensor) -> tf.Tensor:
    """
    Calculate the proportion of foreground voxels per 3D volume.

    :param y: shape = (batch, dim1, dim2, dim3), a 3D label tensor
    :return: shape = (batch,)
    """
    y = tf.cast(y >= 0.5, dtype=tf.float32)
    return tf.reduce_sum(y, axis=[1, 2, 3]) / tf.reduce_sum(
        tf.ones_like(y), axis=[1, 2, 3]
    )
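To make the shapes and the SUM reduction described in the docstrings concrete, here is a minimal usage sketch. It is an illustration only, assuming DeepReg and TensorFlow are installed, and relies only on the signatures shown in the listing above:

import tensorflow as tf

from deepreg.loss.label import DiceScore, JaccardLoss

# Shapes follow the docstrings: (batch, dim1, dim2, dim3).
y_true = tf.cast(tf.random.uniform((2, 4, 4, 4)) > 0.5, tf.float32)
y_pred = tf.random.uniform((2, 4, 4, 4))  # soft predictions in [0, 1]

# call() returns one value per sample, shape (batch,).
dice = DiceScore()
per_sample = dice.call(y_true, y_pred)

# Invoking the loss object applies the SUM reduction, yielding a scalar;
# JaccardLoss negates JaccardIndex via NegativeLossMixin.
jaccard_loss = JaccardLoss()
scalar = jaccard_loss(y_true, y_pred)
print(per_sample.shape, scalar.shape)  # (2,) ()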