|
@@ 282-323 (lines=42) @@
|
| 279 |
|
expected = tf.constant(expected) |
| 280 |
|
assert is_equal_tf(got[0], expected) |
| 281 |
|
|
| 282 |
|
@pytest.mark.parametrize("binary", [True, False])
@pytest.mark.parametrize("background_weight", [0.0, 0.1, 0.5, 1.0])
@pytest.mark.parametrize("shape", [(1,), (10,), (100,), (2, 3), (2, 3, 4)])
def test_exact_value(self, binary: bool, background_weight: float, shape: Tuple):
    """
    Check the Jaccard index against a manually computed reference.

    :param binary: whether labels are thresholded at 0.5 into {0, 1}.
    :param background_weight: weight given to the background class,
        the foreground class receives ``1 - background_weight``.
    :param shape: shape of a single sample, without the batch axis.
    """
    # inputs: one random sample with a leading batch axis of size 1
    batched_shape = (1,) + shape
    foreground_weight = 1 - background_weight
    tf.random.set_seed(0)
    y_true = tf.random.uniform(shape=batched_shape)
    y_pred = tf.random.uniform(shape=batched_shape)

    # value produced by the implementation under test
    metric = label.JaccardIndex(
        binary=binary,
        background_weight=background_weight,
    )
    got = metric.call(y_true=y_true, y_pred=y_pred)

    # reference value, computed on tensors flattened to (batch, -1)
    # so that a reduction over axis=1 covers every element of a sample
    flatten = tf.keras.layers.Flatten()
    true_flat = flatten(y_true)
    pred_flat = flatten(y_pred)
    if binary:
        true_flat = tf.cast(true_flat >= 0.5, dtype=true_flat.dtype)
        pred_flat = tf.cast(pred_flat >= 0.5, dtype=pred_flat.dtype)

    # intersection term: weighted foreground overlap + weighted background overlap
    fg_overlap = tf.reduce_sum(true_flat * pred_flat, axis=1)
    bg_overlap = tf.reduce_sum((1 - true_flat) * (1 - pred_flat), axis=1)
    num = foreground_weight * fg_overlap + background_weight * bg_overlap

    # union term: weighted sum of both masks minus the intersection
    fg_total = tf.reduce_sum(true_flat + pred_flat, axis=1)
    bg_total = tf.reduce_sum((1 - true_flat) + (1 - pred_flat), axis=1)
    denom = foreground_weight * fg_total + background_weight * bg_total
    denom = denom - num

    expected = (num + EPS) / (denom + EPS)

    assert is_equal_tf(got, expected)
| 324 |
|
|
| 325 |
|
def test_get_config(self): |
| 326 |
|
got = label.JaccardIndex().get_config() |
|
@@ 79-120 (lines=42) @@
|
| 76 |
|
expected = tf.constant(expected) |
| 77 |
|
assert is_equal_tf(got[0], expected) |
| 78 |
|
|
| 79 |
|
@pytest.mark.parametrize("binary", [True, False])
@pytest.mark.parametrize("background_weight", [0.0, 0.1, 0.5, 1.0])
@pytest.mark.parametrize("shape", [(1,), (10,), (100,), (2, 3), (2, 3, 4)])
def test_exact_value(self, binary: bool, background_weight: float, shape: Tuple):
    """
    Check the Dice score against a manually computed reference.

    :param binary: whether labels are thresholded at 0.5 into {0, 1}.
    :param background_weight: weight given to the background class,
        the foreground class receives ``1 - background_weight``.
    :param shape: shape of a single sample, without the batch axis.
    """
    # inputs: one random sample with a leading batch axis of size 1
    batched_shape = (1,) + shape
    foreground_weight = 1 - background_weight
    tf.random.set_seed(0)
    y_true = tf.random.uniform(shape=batched_shape)
    y_pred = tf.random.uniform(shape=batched_shape)

    # value produced by the implementation under test
    metric = label.DiceScore(
        binary=binary,
        background_weight=background_weight,
    )
    got = metric.call(y_true=y_true, y_pred=y_pred)

    # reference value, computed on tensors flattened to (batch, -1)
    # so that a reduction over axis=1 covers every element of a sample
    flatten = tf.keras.layers.Flatten()
    true_flat = flatten(y_true)
    pred_flat = flatten(y_pred)
    if binary:
        true_flat = tf.cast(true_flat >= 0.5, dtype=true_flat.dtype)
        pred_flat = tf.cast(pred_flat >= 0.5, dtype=pred_flat.dtype)

    # numerator: twice the weighted foreground + background overlap
    fg_overlap = tf.reduce_sum(true_flat * pred_flat, axis=1)
    bg_overlap = tf.reduce_sum((1 - true_flat) * (1 - pred_flat), axis=1)
    num = foreground_weight * fg_overlap + background_weight * bg_overlap
    num = num * 2

    # denominator: weighted sum of both masks for each class
    fg_total = tf.reduce_sum(true_flat + pred_flat, axis=1)
    bg_total = tf.reduce_sum((1 - true_flat) + (1 - pred_flat), axis=1)
    denom = foreground_weight * fg_total + background_weight * bg_total

    expected = (num + EPS) / (denom + EPS)

    assert is_equal_tf(got, expected)
| 121 |
|
|
| 122 |
|
def test_get_config(self): |
| 123 |
|
got = label.DiceScore().get_config() |