@@ 46-65 (lines=20) @@ | ||
def y_pred(self):
    """Build a constant prediction array for the Dice tests.

    Returns:
        A float array of shape ``self.shape`` where every entry is 0.3.
    """
    return np.full(shape=self.shape, fill_value=0.3)
|
45 | ||
@pytest.mark.parametrize(
    "binary,neg_weight,scales,expected",
    [
        (True, 0.0, None, 0.0),
        (False, 0.0, None, 0.4),
        (False, 0.2, None, 0.4 / 0.94),
        (False, 0.2, [0, 0], 0.4 / 0.94),
        (False, 0.2, [0, 1], 0.46030036),
    ],
)
def test_call(self, y_true, y_pred, binary, neg_weight, scales, expected):
    """Check DiceScore and DiceLoss against a known per-sample value.

    ``DiceScore.call`` must yield ``expected`` for every sample in the batch,
    and ``DiceLoss.call``, built with the same arguments, must yield its
    negation.
    """
    # call() reduces to one value per sample, i.e. an array of shape (batch, ).
    per_sample = np.full(self.shape[0], expected)
    ctor_kwargs = dict(binary=binary, neg_weight=neg_weight, scales=scales)
    # The score is compared as-is; the loss is expected to be the negated score.
    for metric_cls, sign in ((label.DiceScore, 1), (label.DiceLoss, -1)):
        result = metric_cls(**ctor_kwargs).call(y_true=y_true, y_pred=y_pred)
        assert is_equal_tf(result, sign * per_sample)
|
66 | ||
67 | def test_get_config(self): |
|
68 | got = label.DiceScore().get_config() |
|
@@ 132-150 (lines=19) @@ | ||
def y_pred(self):
    """Build a constant prediction array for the Jaccard tests.

    Returns:
        A float array of shape ``self.shape`` where every entry is 0.3.
    """
    return np.full(shape=self.shape, fill_value=0.3)
|
131 | ||
@pytest.mark.parametrize(
    "binary,scales,expected",
    [
        (True, None, 0),
        (False, None, 0.25),
        (False, [0, 0], 0.25),
        (False, [0, 1], 0.17484076),
    ],
)
def test_call(self, y_true, y_pred, binary, scales, expected):
    """Check JaccardIndex and JaccardLoss against a known per-sample value.

    ``JaccardIndex.call`` must yield ``expected`` for every sample in the
    batch, and ``JaccardLoss.call``, built with the same arguments, must
    yield its negation.
    """
    # call() reduces to one value per sample, i.e. an array of shape (batch, ).
    per_sample = np.full(self.shape[0], expected)
    ctor_kwargs = dict(binary=binary, scales=scales)
    # The index is compared as-is; the loss is expected to be the negated index.
    for metric_cls, sign in ((label.JaccardIndex, 1), (label.JaccardLoss, -1)):
        result = metric_cls(**ctor_kwargs).call(y_true=y_true, y_pred=y_pred)
        assert is_equal_tf(result, sign * per_sample)
|
151 | ||
152 | def test_get_config(self): |
|
153 | got = label.JaccardIndex().get_config() |