Completed
Push — main ( de3728...ca54a2 )
by Yunguan
19s queued 13s
created

test.unit.test_loss_image   A

Complexity

Total Complexity 12

Size/Duplication

Total Lines 244
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 12
eloc 155
dl 0
loc 244
rs 10
c 0
b 0
f 0

9 Methods

Rating   Name   Duplication   Size   Complexity  
A TestGlobalMutualInformation.test_get_config() 0 9 1
A TestLocalNormalizedCrossCorrelation.test_input_shape() 0 23 1
A TestLocalNormalizedCrossCorrelation.test_input_shape_err() 0 20 2
A TestGlobalNormalizedCrossCorrelation.test_output() 0 26 1
A TestLocalNormalizedCrossCorrelation.test_smooth() 0 43 1
A TestLocalNormalizedCrossCorrelation.test_kernel_error() 0 5 2
A TestLocalNormalizedCrossCorrelation.test_exact_value() 0 49 1
A TestLocalNormalizedCrossCorrelation.test_get_config() 0 12 1
A TestGlobalMutualInformation.test_zero_info() 0 17 1

1 Function

Rating   Name   Duplication   Size   Complexity  
A test_kernel_fn() 0 6 1
1
"""
2
Tests for deepreg/loss/image.py in
3
pytest style.
4
Notes: The format of inputs to the function dissimilarity_fn
5
in image.py should preferably be converted to tf tensors beforehand.
6
"""
7
8
from test.unit.util import is_equal_tf
9
from typing import Tuple
10
11
import numpy as np
12
import pytest
13
import tensorflow as tf
14
15
import deepreg.loss.image as image
16
from deepreg.constant import EPS
17
18
19
class TestGlobalMutualInformation:
    """Tests for the global mutual information (GMI) image loss."""

    @pytest.mark.parametrize(
        ("y_true", "y_pred", "shape", "expected"),
        [
            (0.6, 0.3, (3, 3, 3, 3), 0.0),
            (0.6, 0.3, (3, 3, 3, 3, 3), 0.0),
            (0.0, 1.0, (3, 3, 3, 3, 3), 0.0),
        ],
    )
    def test_zero_info(self, y_true, y_pred, shape, expected):
        """
        Test GMI on constant images, which carry no mutual information.

        :param y_true: constant value filling the fixed image.
        :param y_pred: constant value filling the moving image.
        :param shape: shape of both images, batch axis first.
        :param expected: expected loss value for every batch sample.
        """
        fixed_image = np.full(shape, y_true, dtype=np.float64)
        moving_image = np.full(shape, y_pred, dtype=np.float64)
        # one loss value is expected per batch sample
        target = np.full((shape[0],), expected, dtype=np.float64)

        got = image.GlobalMutualInformation().call(fixed_image, moving_image)

        assert is_equal_tf(got, target)

    def test_get_config(self):
        """Test that the default GMI config is reported correctly."""
        loss = image.GlobalMutualInformation()
        assert loss.get_config() == {
            "num_bins": 23,
            "sigma_ratio": 0.5,
            "reduction": tf.keras.losses.Reduction.AUTO,
            "name": "GlobalMutualInformation",
        }
47
48
49
@pytest.mark.parametrize("kernel_size", [3, 5, 7])
@pytest.mark.parametrize("name", ["gaussian", "triangular", "rectangular"])
def test_kernel_fn(kernel_size, name):
    """
    Test that every LNCC 1-D kernel function yields a vector of the requested length.

    :param kernel_size: requested kernel length.
    :param name: key of the kernel function in kernel_fn_dict.
    """
    kernel_fns = image.LocalNormalizedCrossCorrelation.kernel_fn_dict
    assert kernel_fns[name](kernel_size).shape == (kernel_size,)
55
56
57
class TestLocalNormalizedCrossCorrelation:
    """Tests for the local normalized cross-correlation (LNCC) image loss."""

    @pytest.mark.parametrize(
        ("y_true_shape", "y_pred_shape"),
        [
            ((2, 3, 4, 5), (2, 3, 4, 5)),
            ((2, 3, 4, 5), (2, 3, 4, 5, 1)),
            ((2, 3, 4, 5, 1), (2, 3, 4, 5)),
            ((2, 3, 4, 5, 1), (2, 3, 4, 5, 1)),
        ],
    )
    def test_input_shape(self, y_true_shape: Tuple, y_pred_shape: Tuple):
        """
        Test input with / without channel axis.

        The loss is expected to be reduced to one value per batch sample.

        :param y_true_shape: input shape for y_true.
        :param y_pred_shape: input shape for y_pred.
        """
        y_true = tf.ones(shape=y_true_shape)
        y_pred = tf.ones(shape=y_pred_shape)
        got = image.LocalNormalizedCrossCorrelation().call(
            y_true,
            y_pred,
        )
        assert got.shape == y_true_shape[:1]

    @pytest.mark.parametrize(
        ("y_true_shape", "y_pred_shape", "name"),
        [
            ((2, 3, 4, 5), (2, 3, 4, 5, 6), "y_pred"),
            ((2, 3, 4, 5, 6), (2, 3, 4, 5), "y_true"),
        ],
    )
    def test_input_shape_err(
        self, y_true_shape: Tuple, y_pred_shape: Tuple, name: str
    ):
        """
        Test that LNCC rejects images whose channel dimension is larger than one.

        :param y_true_shape: input shape for y_true.
        :param y_pred_shape: input shape for y_pred.
        :param name: name of the tensor having error.
        """
        y_true = tf.ones(shape=y_true_shape)
        y_pred = tf.ones(shape=y_pred_shape)
        with pytest.raises(ValueError) as err_info:
            image.LocalNormalizedCrossCorrelation().call(y_true, y_pred)
        assert f"Last dimension of {name} is not one." in str(err_info.value)

    @pytest.mark.parametrize("value", [0.0, 0.5, 1.0])
    @pytest.mark.parametrize(
        ("smooth_nr", "smooth_dr", "expected"),
        [
            # the expected values are smooth_nr / smooth_dr, including the
            # degenerate x/0 -> inf and 0/0 -> nan cases
            (1e-5, 1e-5, 1),
            (0, 1e-5, 0),
            (1e-5, 0, np.inf),
            (0, 0, np.nan),
            (1e-7, 1e-7, 1),
        ],
    )
    def test_smooth(
        self,
        value: float,
        smooth_nr: float,
        smooth_dr: float,
        expected: float,
    ):
        """
        Test values in extreme cases where variances are all zero.

        Both inputs are the same constant image, so only the smoothing
        constants determine the NCC value at the center voxel.

        :param value: value for input.
        :param smooth_nr: constant for numerator.
        :param smooth_dr: constant for denominator.
        :param expected: target value.
        """
        kernel_size = 5
        mid = kernel_size // 2
        shape = (1, kernel_size, kernel_size, kernel_size, 1)
        y_true = tf.ones(shape=shape) * value
        y_pred = tf.ones(shape=shape) * value

        got = image.LocalNormalizedCrossCorrelation(
            kernel_size=kernel_size,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
        ).calc_ncc(
            y_true,
            y_pred,
        )
        got = got[0, mid, mid, mid, 0]
        expected = tf.constant(expected)
        assert is_equal_tf(got, expected)

    @pytest.mark.parametrize(
        "kernel_type",
        ["rectangular", "gaussian", "triangular"],
    )
    @pytest.mark.parametrize(
        "kernel_size",
        [3, 5, 7],
    )
    def test_exact_value(self, kernel_type, kernel_size):
        """
        Test the exact value at the center of a cube.

        The value is compared against a direct kernel-weighted computation of
        the NCC over the whole cube.

        :param kernel_type: name of kernel.
        :param kernel_size: size of the kernel and the cube.
        """
        # init
        mid = kernel_size // 2
        tf.random.set_seed(0)
        y_true = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1))
        y_pred = tf.random.uniform(shape=(1, kernel_size, kernel_size, kernel_size, 1))
        # The reference formula below smooths with EPS, so pin the loss to the
        # same constants instead of relying on the constructor defaults
        # happening to equal EPS.
        loss = image.LocalNormalizedCrossCorrelation(
            kernel_type=kernel_type,
            kernel_size=kernel_size,
            smooth_nr=EPS,
            smooth_dr=EPS,
        )

        # obtained value
        got = loss.calc_ncc(y_true=y_true, y_pred=y_pred)
        got = got[0, mid, mid, mid, 0]  # center voxel

        # target value: separable 1-D kernel expanded to a 3-D weight cube
        kernel_3d = (
            loss.kernel[:, None, None]
            * loss.kernel[None, :, None]
            * loss.kernel[None, None, :]
        )
        kernel_3d = kernel_3d[None, :, :, :, None]  # add batch and channel axes

        y_true_mean = tf.reduce_sum(y_true * kernel_3d) / loss.kernel_vol
        y_true_normalized = y_true - y_true_mean
        y_true_var = tf.reduce_sum(y_true_normalized ** 2 * kernel_3d)

        y_pred_mean = tf.reduce_sum(y_pred * kernel_3d) / loss.kernel_vol
        y_pred_normalized = y_pred - y_pred_mean
        y_pred_var = tf.reduce_sum(y_pred_normalized ** 2 * kernel_3d)

        cross = tf.reduce_sum(y_true_normalized * y_pred_normalized * kernel_3d)
        expected = (cross ** 2 + EPS) / (y_pred_var * y_true_var + EPS)

        # check
        assert is_equal_tf(got, expected)

    def test_kernel_error(self):
        """Test the error message when using wrong kernel."""
        with pytest.raises(ValueError) as err_info:
            image.LocalNormalizedCrossCorrelation(kernel_type="constant")
        assert "Wrong kernel_type constant for LNCC loss type." in str(err_info.value)

    def test_get_config(self):
        """Test the config is saved correctly."""
        got = image.LocalNormalizedCrossCorrelation().get_config()
        expected = dict(
            kernel_size=9,
            kernel_type="rectangular",
            reduction=tf.keras.losses.Reduction.AUTO,
            name="LocalNormalizedCrossCorrelation",
            smooth_nr=1e-5,
            smooth_dr=1e-5,
        )
        assert got == expected
215
216
217
class TestGlobalNormalizedCrossCorrelation:
    """Tests for the global normalized cross-correlation (GNCC) image loss."""

    @pytest.mark.parametrize(
        ("y_true", "y_pred", "shape", "expected"),
        [
            (0.6, 0.3, (3, 3), 1),
            (0.6, 0.3, (3, 3, 3), 1),
            (0.6, -0.3, (3, 3, 3), 1),
            (0.6, 0.3, (3, 3, 3, 3), 1),
        ],
    )
    def test_output(self, y_true, y_pred, shape, expected):
        """
        Test GNCC on zero-padded constant images.

        Each image is a constant block surrounded by a one-voxel border of
        zeros along every non-batch axis.

        :param y_true: constant value inside the fixed image.
        :param y_pred: constant value inside the moving image.
        :param shape: unpadded image shape, batch axis first.
        :param expected: expected loss value for every batch sample.
        """
        # leave the batch axis alone, pad every other axis by one on each side
        pad_width = ((0, 0),) + ((1, 1),) * (len(shape) - 1)
        fixed_image = np.pad(y_true * tf.ones(shape=shape), pad_width=pad_width)
        moving_image = np.pad(y_pred * tf.ones(shape=shape), pad_width=pad_width)

        got = image.GlobalNormalizedCrossCorrelation().call(
            fixed_image, moving_image
        )

        # one loss value is expected per batch sample
        assert is_equal_tf(got, expected * tf.ones(shape=(shape[0],)))
244