Passed
Pull Request — main (#740)
by Yunguan
01:19
created

test.unit.test_loss_util.test_triangular_kernel1d()   A

Complexity

Conditions 2

Size

Total Lines 15
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 15
rs 9.95
c 0
b 0
f 0
cc 2
nop 1
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/loss/util.py in
5
pytest style
6
"""
7
8
from test.unit.util import is_equal_tf
9
10
import numpy as np
11
import pytest
12
import tensorflow as tf
13
14
from deepreg.loss.util import MultiScaleMixin, NegativeLossMixin, separable_filter
15
16
17
class TestMultiScaleLoss:
    """Tests for the configuration handling of ``MultiScaleMixin``."""

    def test_get_config(self):
        """
        A MultiScaleMixin built without arguments must report this config:
        no scales, a gaussian kernel, AUTO reduction and its class name.
        """
        expected = {
            "scales": None,
            "kernel": "gaussian",
            "reduction": tf.keras.losses.Reduction.AUTO,
            "name": "MultiScaleMixin",
        }
        config = MultiScaleMixin().get_config()
        assert config == expected
28
29
30
def test_separable_filter():
    """
    Test separable_filter on a non-empty 5-D float32 input.

    The input is zero everywhere except the slice ``[:, :, 0, 0, 0]``,
    which holds a 3x3 identity matrix. The kernel is all ones, and the
    filtered output is expected to be all ones with the same shape.
    """
    shape = (3, 3, 3, 3, 1)

    image = np.zeros(shape, dtype=np.float32)
    image[:, :, 0, 0, 0] = np.identity(3, dtype=np.float32)
    image = tf.convert_to_tensor(image, dtype=tf.float32)

    kernel = tf.convert_to_tensor(np.ones(shape, dtype=np.float32), dtype=tf.float32)
    expected = tf.convert_to_tensor(np.ones(shape, dtype=np.float32), dtype=tf.float32)

    filtered = separable_filter(image, kernel)
    assert is_equal_tf(filtered, expected)
48
49
50
class MinusClass(tf.keras.losses.Loss):
    """
    Toy Keras loss whose ``call`` returns ``y_true - y_pred``.

    It serves as the base loss that MinusClassLoss combines with
    NegativeLossMixin.
    """

    def __init__(self):
        super(MinusClass, self).__init__()
        # The name is assigned after the parent constructor has run.
        self.name = "MinusClass"

    def call(self, y_true, y_pred):
        """
        Return the difference between the two inputs.

        :param y_true: ground-truth value
        :param y_pred: predicted value
        :return: ``y_true - y_pred``
        """
        difference = y_true - y_pred
        return difference
57
58
59
class MinusClassLoss(NegativeLossMixin, MinusClass):
    """
    MinusClass with NegativeLossMixin placed first in the MRO.

    Per the cases in test_negative_loss_mixin, ``call`` is expected to
    return ``y_pred - y_true``, i.e. the negated MinusClass result.
    """
61
62
63
@pytest.mark.parametrize(
    "y_true,y_pred,expected",
    [
        (1, 2, 1),
        (2, 1, -1),
        (0, 0, 0),
    ],
)
def test_negative_loss_mixin(y_true, y_pred, expected):
    """
    Test that NegativeLossMixin flips the sign of the wrapped loss.

    MinusClass computes ``y_true - y_pred``. With the mixin applied, the
    result should equal ``y_pred - y_true``.

    :param y_true: ground-truth scalar, converted to a float32 tensor
    :param y_pred: predicted scalar, converted to a float32 tensor
    :param expected: expected value of the negated loss
    """
    loss_fn = MinusClassLoss()
    got = loss_fn.call(
        tf.constant(y_true, dtype=tf.float32),
        tf.constant(y_pred, dtype=tf.float32),
    )
    assert is_equal_tf(got, expected)
84