Completed
Push — main ( de3728...ca54a2 )
by Yunguan
19s queued 13s
created

TestMultiScaleMixin.test_get_config()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
cc 1
nop 1
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/loss/util.py in
5
pytest style
6
"""
7
from test.unit.util import is_equal_tf
8
from typing import List, Optional, Union
9
10
import numpy as np
11
import pytest
12
import tensorflow as tf
13
14
from deepreg.loss.label import DiceLoss, DiceScore
15
from deepreg.loss.util import MultiScaleMixin, separable_filter
16
17
18
class TestMultiScaleMixin:
    """Tests for MultiScaleMixin, both directly and through DiceLoss."""

    def test_err(self):
        """Constructing the mixin with an unsupported kernel name raises ValueError."""
        with pytest.raises(ValueError) as err_info:
            MultiScaleMixin(kernel="unknown")
        assert "Kernel unknown is not supported." in str(err_info.value)

    def test_get_config(self):
        """get_config of a default-constructed mixin reports the default arguments."""
        loss = MultiScaleMixin()
        got = loss.get_config()
        expected = dict(
            scales=None,
            kernel="gaussian",
            reduction=tf.keras.losses.Reduction.AUTO,
            name=None,
        )
        assert got == expected

    @pytest.mark.parametrize("kernel", ["gaussian", "cauchy"])
    @pytest.mark.parametrize("scales", [None, 0, [0], [0, 1], [1, 2]])
    def test_call(self, kernel: str, scales: Optional[Union[List, float, int]]):
        """
        Test MultiScaleMixin using DiceLoss.

        The loss is evaluated for every kernel/scales combination and its
        output is checked, so the test fails on a wrong result shape and not
        only on an exception.

        :param kernel: kernel name.
        :param scales: scaling parameters.
        """
        shape = (2, 3, 4, 5)
        y_true = tf.random.uniform(shape=shape)
        y_pred = tf.random.uniform(shape=shape)

        loss = DiceLoss(kernel=kernel, scales=scales)
        got = loss.call(y_pred=y_pred, y_true=y_true)
        # NOTE(review): assumes call() returns one value per batch sample,
        # i.e. shape (batch,) — confirm against deepreg.loss.label.DiceScore.
        assert got.shape == (shape[0],)
50
51
52
def test_negative_loss_mixin():
    """
    Check that DiceLoss equals DiceScore with the sign flipped.

    Both are evaluated on the same pair of random tensors, and the score must
    match the negated loss.
    """
    input_shape = (2, 3, 4, 5)
    truth = tf.random.uniform(shape=input_shape)
    prediction = tf.random.uniform(shape=input_shape)

    score_value = DiceScore().call(y_pred=prediction, y_true=truth)
    loss_value = DiceLoss().call(y_pred=prediction, y_true=truth)

    assert is_equal_tf(score_value, -loss_value)
62
63
64
def test_separable_filter():
    """
    Check that separable_filter spreads a diagonal of ones over the whole tensor.

    The input is zero except at positions (i, i, 0, 0, 0) for i in 0..2.
    Filtering it with an all-ones kernel must produce an all-ones tensor.
    """
    volume_shape = (3, 3, 3, 3, 1)
    kernel = tf.ones(shape=volume_shape, dtype=tf.float32)

    diagonal = np.zeros(volume_shape)
    for i in range(3):
        # Same as writing a 3x3 identity matrix onto the first two axes.
        diagonal[i, i, 0, 0, 0] = 1.0
    inputs = tf.convert_to_tensor(diagonal, dtype=tf.float32)

    want = tf.ones(shape=volume_shape, dtype=tf.float32)
    assert is_equal_tf(separable_filter(inputs, kernel), want)
77