Completed
Push — main ( de3728...ca54a2 )
by Yunguan
19s queued 13s
created

deepreg.loss.label.CrossEntropy.call()   A

Complexity

Conditions 3

Size

Total Lines 26
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 26
rs 9.7
c 0
b 0
f 0
cc 3
nop 3
1
"""Provide different loss or metrics classes for labels."""
2
3
import tensorflow as tf
4
5
from deepreg.constant import EPS
6
from deepreg.loss.util import MultiScaleMixin, NegativeLossMixin
7
from deepreg.registry import REGISTRY
8
9
10
class SumSquaredDifference(tf.keras.losses.Loss):
    """
    Per-sample mean of the squared difference between y_true and y_pred.

    Despite the name, the squared differences are averaged rather than summed;
    the name is kept for convention. Inputs must be at least 1d tensors, the
    first axis being the batch axis.
    """

    def __init__(self, name: str = "SumSquaredDifference", **kwargs):
        """
        Init.

        :param name: name of the loss.
        :param kwargs: additional arguments forwarded to tf.keras.losses.Loss.
        """
        super().__init__(name=name, **kwargs)
        # collapses all non-batch axes so the mean can be taken along axis 1
        self.flatten = tf.keras.layers.Flatten()

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return the mean squared difference of each sample in a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        squared = tf.math.squared_difference(y_true, y_pred)
        per_sample = self.flatten(squared)  # (batch, d)
        return tf.reduce_mean(per_sample, axis=1)
44
45
46
@REGISTRY.register_loss(name="ssd")
class SumSquaredDifferenceLoss(MultiScaleMixin, SumSquaredDifference):
    """SumSquaredDifference registered as the "ssd" loss, with multi-scaling options."""
49
50
51
class DiceScore(tf.keras.losses.Loss):
    """
    Define (optionally generalised) dice score.

    With w_bg = background_weight, w_fg = 1 - w_bg,
    y_prod = y_true * y_pred and y_sum = y_true + y_pred:

        num   = 2 * (w_fg * y_true * y_pred + w_bg * (1 - y_true) * (1 - y_pred))
              = 2 * (y_prod - w_bg * y_sum + w_bg)
        denom = w_fg * (y_true + y_pred) + w_bg * ((1 - y_true) + (1 - y_pred))
              = (1 - 2 * w_bg) * y_sum + 2 * w_bg
        dice score = num / denom

    num and denom are summed over every axis except the batch axis before the
    division. With w_bg == 0 this reduces to 2 * |A.B| / (|A| + |B|).

    Reference:
        Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss
        function for highly unbalanced segmentations." Deep learning in medical image
        analysis and multimodal learning for clinical decision support.
        Springer, Cham, 2017. 240-248.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        name: str = "DiceScore",
        **kwargs,
    ):
        """
        Init.

        :param binary: if True, threshold y_true and y_pred at 0.5 to 0 or 1.
        :param background_weight: weight for background, where y == 0; must be
            within [0, 1].
        :param smooth_nr: small constant added to the numerator.
        :param smooth_dr: small constant added to the denominator, avoiding a
            division by zero when both inputs are empty.
        :param name: name of the loss.
        :param kwargs: additional arguments forwarded to tf.keras.losses.Loss.
        :raises ValueError: if background_weight is outside [0, 1].
        """
        super().__init__(name=name, **kwargs)
        if background_weight < 0 or 1 < background_weight:
            raise ValueError(
                "The background weight for Dice Score must be "
                f"within [0, 1], got {background_weight}."
            )

        self.binary = binary
        self.background_weight = background_weight
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr
        # collapses all non-batch axes: (batch, ...) -> (batch, d)
        self.flatten = tf.keras.layers.Flatten()

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return the dice score of each sample in a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            # threshold at 0.5 while keeping each input's dtype
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        y_true, y_pred = self.flatten(y_true), self.flatten(y_pred)

        # per-sample foreground statistics
        overlap = tf.reduce_sum(y_true * y_pred, axis=1)
        total = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalised form; vol supplies the constant term summed per voxel
            w_bg = self.background_weight
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = 2 * (overlap - w_bg * total + w_bg * vol)
            denominator = (1 - 2 * w_bg) * total + 2 * w_bg * vol
            return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)

        # foreground only
        return (2 * overlap + self.smooth_nr) / (total + self.smooth_dr)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {
            **super().get_config(),
            "binary": self.binary,
            "background_weight": self.background_weight,
            "smooth_nr": self.smooth_nr,
            "smooth_dr": self.smooth_dr,
        }
154
155
156
@REGISTRY.register_loss(name="dice")
class DiceLoss(NegativeLossMixin, MultiScaleMixin, DiceScore):
    """DiceScore with its sign flipped, registered as "dice", with multi-scaling."""
159
160
161
class CrossEntropy(tf.keras.losses.Loss):
    """
    Define weighted cross-entropy.

    With w_bg = background_weight and w_fg = 1 - w_bg, per sample:

        loss = - w_fg * mean(y_true * log(y_pred + smooth))
               - w_bg * mean((1 - y_true) * log(1 - y_pred + smooth))

    where the means run over every axis except the batch axis. When
    w_bg == 0 only the foreground term is computed.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth: float = EPS,
        name: str = "CrossEntropy",
        **kwargs,
    ):
        """
        Init.

        :param binary: if True, threshold y_true and y_pred at 0.5 to 0 or 1.
        :param background_weight: weight for background, where y == 0; must be
            within [0, 1].
        :param smooth: constant added inside each log to avoid log(0).
        :param name: name of the loss.
        :param kwargs: additional arguments forwarded to tf.keras.losses.Loss.
        :raises ValueError: if background_weight is outside [0, 1].
        """
        super().__init__(name=name, **kwargs)
        if background_weight < 0 or 1 < background_weight:
            raise ValueError(
                "The background weight for Cross Entropy must be "
                f"within [0, 1], got {background_weight}."
            )
        self.binary = binary
        self.background_weight = background_weight
        self.smooth = smooth
        # collapses all non-batch axes: (batch, ...) -> (batch, d)
        self.flatten = tf.keras.layers.Flatten()

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return the weighted cross-entropy of each sample in a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            # threshold at 0.5 while keeping each input's dtype
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        y_true, y_pred = self.flatten(y_true), self.flatten(y_pred)

        fg_term = -tf.reduce_mean(y_true * tf.math.log(y_pred + self.smooth), axis=1)
        if self.background_weight > 0:
            w_bg = self.background_weight
            bg_log = tf.math.log(1 - y_pred + self.smooth)
            bg_term = -tf.reduce_mean((1 - y_true) * bg_log, axis=1)
            return (1 - w_bg) * fg_term + w_bg * bg_term
        return fg_term

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {
            **super().get_config(),
            "binary": self.binary,
            "background_weight": self.background_weight,
            "smooth": self.smooth,
        }
233
234
235
@REGISTRY.register_loss(name="cross-entropy")
class CrossEntropyLoss(MultiScaleMixin, CrossEntropy):
    """CrossEntropy registered as "cross-entropy", with multi-scaling options."""
238
239
240
class JaccardIndex(DiceScore):
    """
    Define (optionally generalised) Jaccard index, i.e. intersection over union.

    Foreground only (background_weight == 0):

        Jaccard index = sum(y_true * y_pred)
                        / (sum(y_true + y_pred) - sum(y_true * y_pred))

    Generalised, with w_bg = background_weight, w_fg = 1 - w_bg,
    y_prod = y_true * y_pred and y_sum = y_true + y_pred:

        num   = w_fg * y_true * y_pred + w_bg * (1 - y_true) * (1 - y_pred)
              = y_prod - w_bg * y_sum + w_bg
        denom = w_fg * (y_sum - y_prod) + w_bg * (1 - y_prod)
              = (1 - w_bg) * y_sum - y_prod + w_bg
        Jaccard index = num / denom

    num and denom are summed over every axis except the batch axis before the
    division; smooth_nr and smooth_dr are added to them respectively.
    """

    def __init__(
        self,
        binary: bool = False,
        background_weight: float = 0.0,
        smooth_nr: float = EPS,
        smooth_dr: float = EPS,
        name: str = "JaccardIndex",
        **kwargs,
    ):
        """
        Init.

        All arguments are forwarded to DiceScore, which validates
        background_weight and stores the settings.

        :param binary: if True, threshold y_true and y_pred at 0.5 to 0 or 1.
        :param background_weight: weight for background, where y == 0.
        :param smooth_nr: small constant added to the numerator.
        :param smooth_dr: small constant added to the denominator.
        :param name: name of the loss.
        :param kwargs: additional arguments.
        """
        super().__init__(
            binary=binary,
            background_weight=background_weight,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            name=name,
            **kwargs,
        )

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return the Jaccard index of each sample in a batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            # threshold at 0.5 while keeping each input's dtype
            y_true = tf.cast(y_true >= 0.5, dtype=y_true.dtype)
            y_pred = tf.cast(y_pred >= 0.5, dtype=y_pred.dtype)

        # (batch, ...) -> (batch, d)
        y_true, y_pred = self.flatten(y_true), self.flatten(y_pred)

        # per-sample foreground statistics
        intersection = tf.reduce_sum(y_true * y_pred, axis=1)
        total = tf.reduce_sum(y_true + y_pred, axis=1)

        if self.background_weight > 0:
            # generalised form; vol supplies the constant term summed per voxel
            w_bg = self.background_weight
            vol = tf.reduce_sum(tf.ones_like(y_true), axis=1)
            numerator = intersection - w_bg * total + w_bg * vol
            denominator = (1 - w_bg) * total - intersection + w_bg * vol
        else:
            numerator = intersection
            denominator = total - intersection  # union

        return (numerator + self.smooth_nr) / (denominator + self.smooth_dr)
326
327
328
@REGISTRY.register_loss(name="jaccard")
class JaccardLoss(NegativeLossMixin, MultiScaleMixin, JaccardIndex):
    """JaccardIndex with its sign flipped, registered as "jaccard", with multi-scaling."""
331
332
333
def compute_centroid(mask: tf.Tensor, grid: tf.Tensor) -> tf.Tensor:
    """
    Calculate the centroid of the mask.

    Voxels with mask >= 0.5 count as foreground; the centroid is (up to the EPS
    smoothing) the mean grid coordinate of those voxels. EPS is added to both
    numerator and denominator, so an empty mask yields 1.0 on every coordinate
    instead of NaN.

    :param mask: shape = (batch, dim1, dim2, dim3)
    :param grid: shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch, 3), batch of vectors denoting
             location of centroids, in the dtype of grid.
    """
    assert (
        len(mask.shape) == 4
    ), f"mask must be (batch, dim1, dim2, dim3), got shape {mask.shape}"
    assert (
        len(grid.shape) == 5
    ), f"grid must be (1, dim1, dim2, dim3, 3), got shape {grid.shape}"
    # binarise in the grid's dtype (previously hard-coded float32, which made the
    # product below fail for e.g. float64 grids)
    bool_mask = tf.expand_dims(
        tf.cast(mask >= 0.5, dtype=grid.dtype), axis=4
    )  # (batch, dim1, dim2, dim3, 1)
    masked_grid = bool_mask * grid  # (batch, dim1, dim2, dim3, 3)
    numerator = tf.reduce_sum(masked_grid, axis=[1, 2, 3])  # (batch, 3)
    denominator = tf.reduce_sum(bool_mask, axis=[1, 2, 3])  # (batch, 1)
    return (numerator + EPS) / (denominator + EPS)  # (batch, 3)
350
351
352
def compute_centroid_distance(
    y_true: tf.Tensor, y_pred: tf.Tensor, grid: tf.Tensor
) -> tf.Tensor:
    """
    Return the Euclidean distance between the centroids of two masks.

    Both centroids are computed with compute_centroid on the same grid.

    :param y_true: tensor, shape = (batch, dim1, dim2, dim3)
    :param y_pred: tensor, shape = (batch, dim1, dim2, dim3)
    :param grid: tensor, shape = (1, dim1, dim2, dim3, 3)
    :return: shape = (batch,)
    """
    pred_centroid = compute_centroid(mask=y_pred, grid=grid)  # (batch, 3)
    true_centroid = compute_centroid(mask=y_true, grid=grid)  # (batch, 3)
    offset = pred_centroid - true_centroid
    squared_norm = tf.reduce_sum(offset ** 2, axis=1)  # (batch,)
    return tf.sqrt(squared_norm)
365
366
367
def foreground_proportion(y: tf.Tensor) -> tf.Tensor:
    """
    Return the fraction (in [0, 1]) of foreground voxels in each 3D volume.

    A voxel is foreground when y >= 0.5.

    :param y: shape = (batch, dim1, dim2, dim3), a 3D label tensor
    :return: shape = (batch,), float32
    """
    spatial_axes = [1, 2, 3]
    foreground = tf.cast(y >= 0.5, dtype=tf.float32)
    n_foreground = tf.reduce_sum(foreground, axis=spatial_axes)
    n_voxels = tf.reduce_sum(tf.ones_like(foreground), axis=spatial_axes)
    return n_foreground / n_voxels
377