1
|
|
|
# coding=utf-8 |
2
|
|
|
|
3
|
|
|
""" |
4
|
|
|
Tests for deepreg/loss/deform.py in pytest style
5
|
|
|
""" |
6
|
|
|
from test.unit.util import is_equal_tf |
7
|
|
|
|
8
|
|
|
import pytest |
9
|
|
|
import tensorflow as tf |
10
|
|
|
|
11
|
|
|
import deepreg.loss.deform as deform |
12
|
|
|
|
13
|
|
|
|
14
|
|
|
def test_gradient_dx():
    """
    Check gradient_dx on a constant volume.

    A (4, 50, 50, 50) tensor of ones is differentiated along the x axis.
    The test expects an all-zero result of shape (4, 48, 48, 48).
    """
    expected_grad = tf.zeros([4, 48, 48, 48])
    constant_volume = tf.ones([4, 50, 50, 50])
    assert is_equal_tf(deform.gradient_dx(constant_volume), expected_grad)
20
|
|
|
|
21
|
|
|
|
22
|
|
|
def test_gradient_dy():
    """
    Check gradient_dy on a constant volume.

    A (4, 50, 50, 50) tensor of ones is differentiated along the y axis.
    The test expects an all-zero result of shape (4, 48, 48, 48).
    """
    expected_grad = tf.zeros([4, 48, 48, 48])
    constant_volume = tf.ones([4, 50, 50, 50])
    assert is_equal_tf(deform.gradient_dy(constant_volume), expected_grad)
28
|
|
|
|
29
|
|
|
|
30
|
|
|
def test_gradient_dz():
    """
    Check gradient_dz on a constant volume.

    A (4, 50, 50, 50) tensor of ones is differentiated along the z axis.
    The test expects an all-zero result of shape (4, 48, 48, 48).
    """
    expected_grad = tf.zeros([4, 48, 48, 48])
    constant_volume = tf.ones([4, 50, 50, 50])
    assert is_equal_tf(deform.gradient_dz(constant_volume), expected_grad)
36
|
|
|
|
37
|
|
|
|
38
|
|
|
def test_gradient_dxyz():
    """
    Check gradient_dxyz with each single-axis gradient function in turn.

    A (4, 50, 50, 50, 3) tensor of ones is passed to gradient_dxyz together
    with gradient_dx, gradient_dy and gradient_dz, one at a time. Each
    result is expected to be all zeros with shape (4, 48, 48, 48, 3).
    The trailing axis of size 3 is presumably a per-voxel vector channel
    (TODO confirm against deform.gradient_dxyz).
    """
    expected_grad = tf.zeros([4, 48, 48, 48, 3])
    axis_gradients = (deform.gradient_dx, deform.gradient_dy, deform.gradient_dz)
    for axis_gradient in axis_gradients:
        constant_field = tf.ones([4, 50, 50, 50, 3])
        got = deform.gradient_dxyz(constant_field, axis_gradient)
        assert is_equal_tf(got, expected_grad)
57
|
|
|
|
58
|
|
|
|
59
|
|
|
class TestGradientNorm:
    """Tests for the deform.GradientNorm loss."""

    @pytest.mark.parametrize("l1", [True, False])
    def test_call(self, l1):
        """
        A constant (4, 50, 50, 50, 3) input should yield a zero tensor of
        shape (4,), i.e. one value per batch element, for both l1=True and
        l1=False.
        """
        loss_fn = deform.GradientNorm(l1=l1)
        constant_field = tf.ones([4, 50, 50, 50, 3])
        assert is_equal_tf(loss_fn(constant_field), tf.zeros([4]))

    def test_get_config(self):
        """A default-constructed GradientNorm must report exactly this config."""
        config = deform.GradientNorm().get_config()
        assert config == dict(
            name="GradientNorm",
            l1=False,
            dtype="float32",
            trainable=True,
        )
80
|
|
|
|
81
|
|
|
|
82
|
|
|
def test_bending_energy():
    """
    Check BendingEnergy on a constant field.

    A (4, 50, 50, 50, 3) tensor of ones should produce a zero tensor of
    shape (4,), one value per batch element.
    """
    loss_fn = deform.BendingEnergy()
    constant_field = tf.ones([4, 50, 50, 50, 3])
    assert is_equal_tf(loss_fn(constant_field), tf.zeros([4]))
92
|
|
|
|