test.unit.test_loss_deform   A
last analyzed

Complexity

Total Complexity 7

Size/Duplication

Total Lines 92
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 7
eloc 57
dl 0
loc 92
rs 10
c 0
b 0
f 0

5 Functions

Rating   Name   Duplication   Size   Complexity  
A test_gradient_dz() 0 6 1
A test_gradient_dx() 0 6 1
A test_gradient_dy() 0 6 1
A test_gradient_dxyz() 0 19 1
A test_bending_energy() 0 10 1

2 Methods

Rating   Name   Duplication   Size   Complexity  
A TestGradientNorm.test_get_config() 0 9 1
A TestGradientNorm.test_call() 0 10 1
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/loss/deform.py in pytest style
5
"""
6
from test.unit.util import is_equal_tf
7
8
import pytest
9
import tensorflow as tf
10
11
import deepreg.loss.deform as deform
12
13
14
def test_gradient_dx():
    """Check that the x-axis gradient of a constant 3D image batch is zero.

    A batch of all-ones volumes of shape [4, 50, 50, 50] is passed to
    ``deform.gradient_dx``; the result is compared against an all-zeros
    tensor of shape [4, 48, 48, 48].
    """
    constant_volume = tf.ones([4, 50, 50, 50])
    result = deform.gradient_dx(constant_volume)
    # A constant input carries no spatial variation, so every gradient
    # value is expected to be zero.
    assert is_equal_tf(result, tf.zeros([4, 48, 48, 48]))
20
21
22
def test_gradient_dy():
    """Check that the y-axis gradient of a constant 3D image batch is zero.

    A batch of all-ones volumes of shape [4, 50, 50, 50] is passed to
    ``deform.gradient_dy``; the result is compared against an all-zeros
    tensor of shape [4, 48, 48, 48].
    """
    constant_volume = tf.ones([4, 50, 50, 50])
    result = deform.gradient_dy(constant_volume)
    # A constant input carries no spatial variation, so every gradient
    # value is expected to be zero.
    assert is_equal_tf(result, tf.zeros([4, 48, 48, 48]))
28
29
30
def test_gradient_dz():
    """Check that the z-axis gradient of a constant 3D image batch is zero.

    A batch of all-ones volumes of shape [4, 50, 50, 50] is passed to
    ``deform.gradient_dz``; the result is compared against an all-zeros
    tensor of shape [4, 48, 48, 48].
    """
    constant_volume = tf.ones([4, 50, 50, 50])
    result = deform.gradient_dz(constant_volume)
    # A constant input carries no spatial variation, so every gradient
    # value is expected to be zero.
    assert is_equal_tf(result, tf.zeros([4, 48, 48, 48]))
36
37
38
def test_gradient_dxyz():
    """Check ``deform.gradient_dxyz`` with each per-axis gradient function.

    For gradient_dx, gradient_dy and gradient_dz in turn, an all-ones
    tensor of shape [4, 50, 50, 50, 3] is passed in and the result is
    compared against an all-zeros tensor of shape [4, 48, 48, 48, 3].
    """
    axis_gradients = (
        deform.gradient_dx,
        deform.gradient_dy,
        deform.gradient_dz,
    )
    for gradient_fn in axis_gradients:
        # Fresh input and expectation per axis, as each case is independent.
        constant_field = tf.ones([4, 50, 50, 50, 3])
        result = deform.gradient_dxyz(constant_field, gradient_fn)
        assert is_equal_tf(result, tf.zeros([4, 48, 48, 48, 3]))
57
58
59
class TestGradientNorm:
    """Tests for ``deform.GradientNorm``."""

    @pytest.mark.parametrize("l1", [True, False])
    def test_call(self, l1):
        """Check that calling the loss on a constant tensor gives zeros.

        Runs once with ``l1=True`` and once with ``l1=False``. The input is
        an all-ones tensor of shape [4, 50, 50, 50, 3]; the result is
        compared against an all-zeros tensor of shape [4].
        """
        loss = deform.GradientNorm(l1=l1)
        got = loss(tf.ones([4, 50, 50, 50, 3]))
        assert is_equal_tf(got, tf.zeros([4]))

    def test_get_config(self):
        """Check the config dict of a default-constructed GradientNorm."""
        assert deform.GradientNorm().get_config() == {
            "name": "GradientNorm",
            "l1": False,
            "dtype": "float32",
            "trainable": True,
        }
80
81
82
def test_bending_energy():
    """Check that the bending energy of a constant tensor is zero.

    An all-ones tensor of shape [4, 50, 50, 50, 3] is passed to a
    default-constructed ``deform.BendingEnergy``; the result is compared
    against an all-zeros tensor of shape [4].
    """
    energy = deform.BendingEnergy()
    got = energy(tf.ones([4, 50, 50, 50, 3]))
    assert is_equal_tf(got, tf.zeros([4]))
92