Code Duplication    Length = 42-42 lines in 2 locations

test/unit/test_loss_label.py 2 locations

@@ 282-323 (lines=42) @@
279
        expected = tf.constant(expected)
280
        assert is_equal_tf(got[0], expected)
281
282
    @pytest.mark.parametrize("binary", [True, False])
    @pytest.mark.parametrize("background_weight", [0.0, 0.1, 0.5, 1.0])
    @pytest.mark.parametrize("shape", [(1,), (10,), (100,), (2, 3), (2, 3, 4)])
    def test_exact_value(self, binary: bool, background_weight: float, shape: Tuple):
        """
        Check JaccardIndex.call against a Jaccard index computed by hand in the test.

        :param binary: whether both inputs are thresholded at 0.5 to binary values.
        :param background_weight: the weight of the background class.
        :param shape: shape of one sample, excluding the batch axis.
        """
        batched_shape = (1,) + shape  # prepend a batch axis of size 1
        tf.random.set_seed(0)
        # draw order matters for reproducibility: truth first, then prediction
        truth = tf.random.uniform(shape=batched_shape)
        pred = tf.random.uniform(shape=batched_shape)

        # value produced by the implementation under test
        got = label.JaccardIndex(
            binary=binary,
            background_weight=background_weight,
        ).call(y_true=truth, y_pred=pred)

        # reference value: flatten each sample, optionally binarize, then
        # weight foreground and background terms separately
        flatten = tf.keras.layers.Flatten()
        truth = flatten(truth)
        pred = flatten(pred)
        if binary:
            truth = tf.cast(truth >= 0.5, dtype=truth.dtype)
            pred = tf.cast(pred >= 0.5, dtype=pred.dtype)

        foreground_weight = 1 - background_weight
        fg_overlap = tf.reduce_sum(truth * pred, axis=1)
        bg_overlap = tf.reduce_sum((1 - truth) * (1 - pred), axis=1)
        fg_total = tf.reduce_sum(truth + pred, axis=1)
        bg_total = tf.reduce_sum((1 - truth) + (1 - pred), axis=1)

        intersection = foreground_weight * fg_overlap + background_weight * bg_overlap
        # union = |A| + |B| - |A ∩ B|, each side weighted per class
        union = foreground_weight * fg_total + background_weight * bg_total - intersection
        expected = (intersection + EPS) / (union + EPS)

        assert is_equal_tf(got, expected)
324
325
    def test_get_config(self):
326
        got = label.JaccardIndex().get_config()
@@ 84-125 (lines=42) @@
81
        expected = tf.constant(expected)
82
        assert is_equal_tf(got[0], expected)
83
84
    @pytest.mark.parametrize("binary", [True, False])
    @pytest.mark.parametrize("background_weight", [0.0, 0.1, 0.5, 1.0])
    @pytest.mark.parametrize("shape", [(1,), (10,), (100,), (2, 3), (2, 3, 4)])
    def test_exact_value(self, binary: bool, background_weight: float, shape: Tuple):
        """
        Check DiceScore.call against a Dice score computed by hand in the test.

        :param binary: whether both inputs are thresholded at 0.5 to binary values.
        :param background_weight: the weight of the background class.
        :param shape: shape of one sample, excluding the batch axis.
        """
        batched_shape = (1,) + shape  # prepend a batch axis of size 1
        tf.random.set_seed(0)
        # draw order matters for reproducibility: truth first, then prediction
        truth = tf.random.uniform(shape=batched_shape)
        pred = tf.random.uniform(shape=batched_shape)

        # value produced by the implementation under test
        got = label.DiceScore(
            binary=binary,
            background_weight=background_weight,
        ).call(y_true=truth, y_pred=pred)

        # reference value: flatten each sample, optionally binarize, then
        # weight foreground and background terms separately
        flatten = tf.keras.layers.Flatten()
        truth = flatten(truth)
        pred = flatten(pred)
        if binary:
            truth = tf.cast(truth >= 0.5, dtype=truth.dtype)
            pred = tf.cast(pred >= 0.5, dtype=pred.dtype)

        foreground_weight = 1 - background_weight
        fg_overlap = tf.reduce_sum(truth * pred, axis=1)
        bg_overlap = tf.reduce_sum((1 - truth) * (1 - pred), axis=1)
        fg_total = tf.reduce_sum(truth + pred, axis=1)
        bg_total = tf.reduce_sum((1 - truth) + (1 - pred), axis=1)

        # Dice = 2 |A ∩ B| / (|A| + |B|), each side weighted per class
        overlap = (foreground_weight * fg_overlap + background_weight * bg_overlap) * 2
        total = foreground_weight * fg_total + background_weight * bg_total
        expected = (overlap + EPS) / (total + EPS)

        assert is_equal_tf(got, expected)
126
127
    def test_get_config(self):
128
        got = label.DiceScore().get_config()