|
@@ 46-65 (lines=20) @@
|
| 43 |
|
def y_pred(self):
    """Return a constant prediction array of shape ``self.shape`` filled with 0.3."""
    # NOTE(review): presumably exposed to the tests as a pytest fixture by a
    # decorator outside this view -- confirm.
    return np.full(shape=self.shape, fill_value=0.3)
| 45 |
|
|
| 46 |
|
@pytest.mark.parametrize(
    "binary,neg_weight,scales,expected",
    [
        (True, 0.0, None, 0.0),
        (False, 0.0, None, 0.4),
        (False, 0.2, None, 0.4 / 0.94),
        (False, 0.2, [0, 0], 0.4 / 0.94),
        (False, 0.2, [0, 1], 0.46030036),
    ],
)
def test_call(self, y_true, y_pred, binary, neg_weight, scales, expected):
    """Check ``DiceScore.call`` and ``DiceLoss.call`` on constant inputs.

    Each case supplies the constructor options (``binary``, ``neg_weight``,
    ``scales``) and the scalar value expected for every sample. The score
    must match that value and the loss must match its negation.
    """
    # ``call`` yields one value per sample, i.e. an array of shape (batch, ).
    batch_size = self.shape[0]
    expected = np.array([expected] * batch_size)

    # Both classes are built from the same options.
    options = dict(binary=binary, neg_weight=neg_weight, scales=scales)

    score = label.DiceScore(**options).call(y_true=y_true, y_pred=y_pred)
    assert is_equal_tf(score, expected)

    loss = label.DiceLoss(**options).call(y_true=y_true, y_pred=y_pred)
    assert is_equal_tf(loss, -expected)
| 66 |
|
|
| 67 |
|
def test_get_config(self): |
| 68 |
|
got = label.DiceScore().get_config() |
|
@@ 132-150 (lines=19) @@
|
| 129 |
|
def y_pred(self):
    """Return a constant prediction array of shape ``self.shape`` filled with 0.3."""
    # NOTE(review): presumably exposed to the tests as a pytest fixture by a
    # decorator outside this view -- confirm.
    return np.full(shape=self.shape, fill_value=0.3)
| 131 |
|
|
| 132 |
|
@pytest.mark.parametrize(
    "binary,scales,expected",
    [
        (True, None, 0),
        (False, None, 0.25),
        (False, [0, 0], 0.25),
        (False, [0, 1], 0.17484076),
    ],
)
def test_call(self, y_true, y_pred, binary, scales, expected):
    """Check ``JaccardIndex.call`` and ``JaccardLoss.call`` on constant inputs.

    Each case supplies the constructor options (``binary``, ``scales``) and
    the scalar value expected for every sample. The index must match that
    value and the loss must match its negation.
    """
    # ``call`` yields one value per sample, i.e. an array of shape (batch, ).
    batch_size = self.shape[0]
    expected = np.array([expected] * batch_size)

    # Both classes are built from the same options.
    options = dict(binary=binary, scales=scales)

    index = label.JaccardIndex(**options).call(y_true=y_true, y_pred=y_pred)
    assert is_equal_tf(index, expected)

    loss = label.JaccardLoss(**options).call(y_true=y_true, y_pred=y_pred)
    assert is_equal_tf(loss, -expected)
| 151 |
|
|
| 152 |
|
def test_get_config(self): |
| 153 |
|
got = label.JaccardIndex().get_config() |