Completed
Push — main ( bd3202...72b597 )
by Yunguan
27s queued 13s
created

deepreg.loss.label   A

Complexity

Total Complexity 22

Size/Duplication

Total Lines 350
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 22
eloc 152
dl 0
loc 350
rs 10
c 0
b 0
f 0

13 Methods

Rating   Name   Duplication   Size   Complexity  
A MultiScaleLoss.__init__() 0 20 1
A MultiScaleLoss.get_config() 0 6 1
A MultiScaleLoss._call() 0 9 1
A JaccardIndex.__init__() 0 20 1
A DiceScore._call() 0 22 2
A CrossEntropy.get_config() 0 6 1
A DiceScore.get_config() 0 6 1
A JaccardIndex.get_config() 0 5 1
A JaccardIndex._call() 0 20 2
A MultiScaleLoss.call() 0 35 4
A CrossEntropy._call() 0 19 2
A DiceScore.__init__() 0 24 1
A CrossEntropy.__init__() 0 24 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A compute_centroid_distance() 0 13 1
A compute_centroid() 0 19 1
A foreground_proportion() 0 9 1
1
"""Provide different loss or metrics classes for labels."""
2
3
from typing import List, Optional
4
5
import tensorflow as tf
6
7
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
8
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
9
from deepreg.loss.util import separable_filter
10
from deepreg.registry import REGISTRY
11
12
# Small positive constant (the Keras backend fuzz factor). It is added to
# numerators, denominators and log arguments below to avoid division by zero
# and log(0).
EPS = tf.keras.backend.epsilon()
13
14
15
class MultiScaleLoss(tf.keras.losses.Loss):
    """
    Base class for multi-scale loss.

    It applies the loss at different scales (gaussian or cauchy smoothing)
    and averages the values obtained at each scale.
    It is assumed that loss values are between 0 and 1.

    Subclasses implement `_call`, which evaluates the loss at a single scale.
    """

    # maps the `kernel` argument to the function building the 1D smoothing kernel
    kernel_fn_dict = dict(gaussian=gaussian_kernel1d, cauchy=cauchy_kernel1d)

    def __init__(
        self,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "MultiScaleLoss",
    ):
        """
        Init.

        :param scales: list of scalars or None, if None (or empty),
            do not apply any scaling. A scale equal to 0 means no smoothing.
        :param kernel: gaussian or cauchy, i.e. a key of `kernel_fn_dict`.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        :raises AssertionError: if kernel is not a supported kernel name.
        """
        super().__init__(reduction=reduction, name=name)
        # validate against the lookup table actually used in `call`,
        # so that the accepted names and the available kernels cannot diverge
        supported_kernels = list(self.kernel_fn_dict)
        assert (
            kernel in supported_kernels
        ), f"kernel must be one of {supported_kernels}, got {kernel!r}"
        self.scales = scales
        self.kernel = kernel

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Use _call to calculate loss at different scales.

        The returned value is the mean of the losses over all scales.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: multi-scale loss.
        """
        # an empty list of scales is handled like None, otherwise
        # tf.add_n would be called without inputs and the mean would divide by 0
        if self.scales is None or len(self.scales) == 0:
            return self._call(y_true=y_true, y_pred=y_pred)

        kernel_fn = self.kernel_fn_dict[self.kernel]
        losses = []
        for scale in self.scales:
            if scale == 0:
                # no smoothing
                scaled_true, scaled_pred = y_true, y_pred
            else:
                # build the kernel once and share it between both tensors
                kernel = kernel_fn(scale)
                # a trailing axis is added for the filter and removed afterwards
                # assumes inputs are (batch, dim1, dim2, dim3) - TODO confirm
                scaled_true = separable_filter(
                    tf.expand_dims(y_true, axis=4), kernel
                )[..., 0]
                scaled_pred = separable_filter(
                    tf.expand_dims(y_pred, axis=4), kernel
                )[..., 0]
            losses.append(self._call(y_true=scaled_true, y_pred=scaled_pred))

        return tf.add_n(losses) / len(self.scales)

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Return loss for a batch at a single scale.

        To be implemented by subclasses.

        :param y_true: ground-truth tensor.
        :param y_pred: predicted tensor.
        :return: loss values computed by the subclass.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["scales"] = self.scales
        config["kernel"] = self.kernel
        return config
98
99
100
class DiceScore(MultiScaleLoss):
    """
    Weighted soft dice score between two label tensors.

    With pos_w + neg_w = 1, y_prod = y_true * y_pred and
    y_sum = y_true + y_pred, the score is built as:

        num   = 2 * (pos_w * y_true * y_pred + neg_w * (1 - y_true) * (1 - y_pred))
              = 2 * (y_prod - neg_w * y_sum + neg_w)
        denom = pos_w * (y_true + y_pred) + neg_w * (1 - y_true + 1 - y_pred)
              = (1 - 2 * neg_w) * y_sum + 2 * neg_w
        dice  = num / denom

    y_prod and y_sum are averaged over every axis except the batch axis.
    """

    def __init__(
        self,
        binary: bool = False,
        neg_weight: float = 0.0,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "DiceScore",
    ):
        """
        Init.

        :param binary: if True, y_true and y_pred are thresholded at 0.5
            to take values 0 or 1.
        :param neg_weight: weight of the negative class, between 0 and 1,
            the positive class receives 1 - neg_weight.
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        assert 0 <= neg_weight <= 1
        self.binary = binary
        self.neg_weight = neg_weight

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Compute the dice score of each sample in the batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            # project both tensors to {0, 1} while keeping their dtypes
            y_true, y_pred = (
                tf.cast(tensor >= 0.5, dtype=tensor.dtype)
                for tensor in (y_true, y_pred)
            )

        # (batch, ...) -> (batch, d)
        flat_true = tf.keras.layers.Flatten()(y_true)
        flat_pred = tf.keras.layers.Flatten()(y_pred)

        neg_w = self.neg_weight
        overlap = tf.reduce_mean(flat_true * flat_pred, axis=1)
        total = tf.reduce_mean(flat_true, axis=1) + tf.reduce_mean(flat_pred, axis=1)

        numerator = 2 * (overlap - neg_w * total + neg_w)
        denominator = (1 - 2 * neg_w) * total + 2 * neg_w
        # EPS on both sides keeps the ratio defined when both tensors are empty
        return (numerator + EPS) / (denominator + EPS)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {
            **super().get_config(),
            "binary": self.binary,
            "neg_weight": self.neg_weight,
        }
172
173
174
@REGISTRY.register_loss(name="dice")
class DiceLoss(NegativeLossMixin, DiceScore):
    """Dice loss registered as "dice": DiceScore with its sign reverted."""
177
178
179
@REGISTRY.register_loss(name="cross-entropy")
class CrossEntropy(MultiScaleLoss):
    """
    Weighted cross-entropy between two label tensors.

    With pos_w = 1 - neg_w, the formulation is:

        loss = - pos_w * y_true * log(y_pred)
               - neg_w * (1 - y_true) * log(1 - y_pred)

    where both terms are averaged over every axis except the batch axis.
    """

    def __init__(
        self,
        binary: bool = False,
        neg_weight: float = 0.0,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "CrossEntropy",
    ):
        """
        Init.

        :param binary: if True, y_true and y_pred are thresholded at 0.5
            to take values 0 or 1.
        :param neg_weight: weight of the negative class, between 0 and 1,
            the positive class receives 1 - neg_weight.
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        assert 0 <= neg_weight <= 1
        self.binary = binary
        self.neg_weight = neg_weight

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Compute the weighted cross-entropy of each sample in the batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            # project both tensors to {0, 1} while keeping their dtypes
            y_true, y_pred = (
                tf.cast(tensor >= 0.5, dtype=tensor.dtype)
                for tensor in (y_true, y_pred)
            )

        # (batch, ...) -> (batch, d)
        flat_true = tf.keras.layers.Flatten()(y_true)
        flat_pred = tf.keras.layers.Flatten()(y_pred)

        # EPS inside the logarithms avoids log(0)
        pos_term = tf.reduce_mean(flat_true * tf.math.log(flat_pred + EPS), axis=1)
        neg_term = tf.reduce_mean(
            (1 - flat_true) * tf.math.log(1 - flat_pred + EPS), axis=1
        )
        return -(1 - self.neg_weight) * pos_term - self.neg_weight * neg_term

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {
            **super().get_config(),
            "binary": self.binary,
            "neg_weight": self.neg_weight,
        }
239
240
241
class JaccardIndex(MultiScaleLoss):
    """
    Soft Jaccard index between two label tensors.

    The formulation is:

        num     = y_true * y_pred
        denom   = y_true + y_pred - y_true * y_pred
        jaccard = num / denom

    where the terms are averaged over every axis except the batch axis.
    """

    def __init__(
        self,
        binary: bool = False,
        scales: Optional[List] = None,
        kernel: str = "gaussian",
        reduction: str = tf.keras.losses.Reduction.SUM,
        name: str = "JaccardIndex",
    ):
        """
        Init.

        :param binary: if True, y_true and y_pred are thresholded at 0.5
            to take values 0 or 1.
        :param scales: list of scalars or None, if None, do not apply any scaling.
        :param kernel: gaussian or cauchy.
        :param reduction: using SUM reduction over batch axis,
            calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
        :param name: str, name of the loss.
        """
        super().__init__(scales=scales, kernel=kernel, reduction=reduction, name=name)
        self.binary = binary

    def _call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Compute the Jaccard index of each sample in the batch.

        :param y_true: shape = (batch, ...)
        :param y_pred: shape = (batch, ...)
        :return: shape = (batch,)
        """
        if self.binary:
            # project both tensors to {0, 1} while keeping their dtypes
            y_true, y_pred = (
                tf.cast(tensor >= 0.5, dtype=tensor.dtype)
                for tensor in (y_true, y_pred)
            )

        # (batch, ...) -> (batch, d)
        flat_true = tf.keras.layers.Flatten()(y_true)
        flat_pred = tf.keras.layers.Flatten()(y_pred)

        intersection = tf.reduce_mean(flat_true * flat_pred, axis=1)
        total = tf.reduce_mean(flat_true, axis=1) + tf.reduce_mean(flat_pred, axis=1)

        # EPS on both sides keeps the ratio defined when both tensors are empty
        return (intersection + EPS) / (total - intersection + EPS)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {**super().get_config(), "binary": self.binary}
298
299
300
@REGISTRY.register_loss(name="jaccard")
class JaccardLoss(NegativeLossMixin, JaccardIndex):
    """Jaccard loss registered as "jaccard": JaccardIndex with its sign reverted."""
303
304
305
def compute_centroid(mask: tf.Tensor, grid: tf.Tensor) -> tf.Tensor:
    """
    Compute the centroid of the foreground of each mask in the batch.

    The mask is thresholded at 0.5 and the centroid is the mean of the grid
    coordinates over the foreground voxels. EPS is added to both the
    coordinate sum and the voxel count, so an empty mask gives a centroid
    with all coordinates equal to 1.

    :param mask: shape = (batch, dim1, dim2, dim3)
    :param grid: shape = (dim1, dim2, dim3, 3)
    :return: shape = (batch, 3), batch of vectors denoting
             location of centroids.
    """
    assert len(mask.shape) == 4
    assert len(grid.shape) == 4

    # (batch, dim1, dim2, dim3, 1), values in {0, 1}
    foreground = tf.cast(mask >= 0.5, dtype=tf.float32)
    foreground = tf.expand_dims(foreground, axis=4)

    # (1, dim1, dim2, dim3, 3), broadcast against the batch axis
    batched_grid = tf.expand_dims(grid, axis=0)

    # coordinates of background voxels are zeroed out
    foreground_coords = foreground * batched_grid  # (batch, dim1, dim2, dim3, 3)

    coord_sum = tf.reduce_sum(foreground_coords, axis=[1, 2, 3])  # (batch, 3)
    voxel_count = tf.reduce_sum(foreground, axis=[1, 2, 3])  # (batch, 1)
    return (coord_sum + EPS) / (voxel_count + EPS)  # (batch, 3)
324
325
326
def compute_centroid_distance(
    y_true: tf.Tensor, y_pred: tf.Tensor, grid: tf.Tensor
) -> tf.Tensor:
    """
    Compute the Euclidean distance between the centroids of two masks.

    :param y_true: tensor, shape = (batch, dim1, dim2, dim3)
    :param y_pred: tensor, shape = (batch, dim1, dim2, dim3)
    :param grid: tensor, shape = (dim1, dim2, dim3, 3)
    :return: shape = (batch,)
    """
    pred_centroid = compute_centroid(mask=y_pred, grid=grid)  # (batch, 3)
    true_centroid = compute_centroid(mask=y_true, grid=grid)  # (batch, 3)
    squared_diff = (pred_centroid - true_centroid) ** 2  # (batch, 3)
    return tf.sqrt(tf.reduce_sum(squared_diff, axis=1))
339
340
341
def foreground_proportion(y: tf.Tensor) -> tf.Tensor:
    """
    Compute the fraction of foreground voxels in each 3D volume.

    Voxels with a value of at least 0.5 are counted as foreground.

    :param y: shape = (batch, dim1, dim2, dim3), a 3D label tensor
    :return: shape = (batch,)
    """
    foreground = tf.cast(y >= 0.5, dtype=tf.float32)
    foreground_count = tf.reduce_sum(foreground, axis=[1, 2, 3])
    voxel_count = tf.reduce_sum(tf.ones_like(foreground), axis=[1, 2, 3])
    return foreground_count / voxel_count
351