Code duplication: a block of 19–20 similar lines appears in 2 locations

File: test/unit/test_loss_label.py (2 locations)

@@ 46-65 (lines=20) @@
43
    def y_pred(self):
44
        return np.ones(shape=self.shape) * 0.3
45
46
    @pytest.mark.parametrize(
        "binary,neg_weight,scales,expected",
        [
            (True, 0.0, None, 0.0),
            (False, 0.0, None, 0.4),
            (False, 0.2, None, 0.4 / 0.94),
            (False, 0.2, [0, 0], 0.4 / 0.94),
            (False, 0.2, [0, 1], 0.46030036),
        ],
    )
    def test_call(self, y_true, y_pred, binary, neg_weight, scales, expected):
        """Check DiceScore and DiceLoss against the same expected value.

        Both classes are constructed with identical arguments. DiceScore must
        match ``expected`` and DiceLoss must match ``-expected``.
        """
        # call returns one value per sample, i.e. shape (batch, )
        expected = np.array([expected] * self.shape[0])
        kwargs = {"binary": binary, "neg_weight": neg_weight, "scales": scales}
        # The score and its loss differ only by sign, so check both in one loop
        # instead of duplicating the construct/call/assert sequence.
        for metric_cls, sign in ((label.DiceScore, 1), (label.DiceLoss, -1)):
            got = metric_cls(**kwargs).call(y_true=y_true, y_pred=y_pred)
            assert is_equal_tf(got, sign * expected), metric_cls.__name__
66
67
    def test_get_config(self):
68
        got = label.DiceScore().get_config()
@@ 132-150 (lines=19) @@
129
    def y_pred(self):
130
        return np.ones(shape=self.shape) * 0.3
131
132
    @pytest.mark.parametrize(
        "binary,scales,expected",
        [
            (True, None, 0),
            (False, None, 0.25),
            (False, [0, 0], 0.25),
            (False, [0, 1], 0.17484076),
        ],
    )
    def test_call(self, y_true, y_pred, binary, scales, expected):
        """Check JaccardIndex and JaccardLoss against the same expected value.

        Both classes are constructed with identical arguments. JaccardIndex must
        match ``expected`` and JaccardLoss must match ``-expected``.
        """
        # call returns one value per sample, i.e. shape (batch, )
        expected = np.array([expected] * self.shape[0])
        kwargs = {"binary": binary, "scales": scales}
        # The index and its loss differ only by sign, so check both in one loop
        # instead of duplicating the construct/call/assert sequence.
        for metric_cls, sign in ((label.JaccardIndex, 1), (label.JaccardLoss, -1)):
            got = metric_cls(**kwargs).call(y_true=y_true, y_pred=y_pred)
            assert is_equal_tf(got, sign * expected), metric_cls.__name__
151
152
    def test_get_config(self):
153
        got = label.JaccardIndex().get_config()