Passed
Pull Request — main (#605)
by
unknown
03:40
created

test.unit.test_loss_util   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 53
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 3
eloc 31
dl 0
loc 53
rs 10
c 0
b 0
f 0

3 Functions

Rating   Name   Duplication   Size   Complexity  
A test_separable_filter() 0 18 1
A test_gaussian_kernel1d() 0 7 1
A test_cauchy_kernel1d() 0 7 1
1
# coding=utf-8
2
3
"""
Tests for deepreg/model/loss/label.py in pytest style.
"""
7
8
from test.unit.util import is_equal_tf
9
10
import numpy as np
11
import pytest
12
import tensorflow as tf
13
14
import deepreg.model.loss.label as label
15
16
17
@pytest.mark.parametrize("sigma", [1, 3, 2.2])
def test_gaussian_kernel1d(sigma):
    """
    Compare label.gaussian_kernel1d against a reference Gaussian kernel.

    The reference samples exp(-0.5 * x**2 / sigma**2) at the integers
    x in [-tail, tail] with tail = int(3 * sigma), then normalises the
    samples so that they sum to one.

    :param sigma: standard deviation of the Gaussian (int or float).
    """
    tail = int(sigma * 3)
    # Integer sample positions centred on zero: -tail, ..., tail.
    grid = np.arange(-tail, tail + 1)
    expected = np.exp(-0.5 * grid ** 2 / sigma ** 2)
    expected = expected / np.sum(expected)
    got = label.gaussian_kernel1d(sigma)
    assert is_equal_tf(got, expected)
24
25
26
@pytest.mark.parametrize("sigma", [1, 3, 2.2])
def test_cauchy_kernel1d(sigma):
    """
    Compare label.cauchy_kernel1d against a reference Cauchy kernel.

    The reference samples 1 / ((x / sigma)**2 + 1) at the integers
    x in [-tail, tail] with tail = int(5 * sigma), then normalises the
    samples so that they sum to one.

    :param sigma: scale parameter of the Cauchy kernel (int or float).
    """
    tail = int(sigma * 5)
    # Integer sample positions centred on zero: -tail, ..., tail.
    grid = np.arange(-tail, tail + 1)
    expected = 1 / ((grid / sigma) ** 2 + 1)
    expected = expected / np.sum(expected)
    got = label.cauchy_kernel1d(sigma)
    assert is_equal_tf(got, expected)
33
34
35
def test_separable_filter():
    """
    Check label.separable_filter on a non-empty float32 tensor.

    The input is all zeros except the slice [:, :, 0, 0, 0], which holds
    a 3x3 identity matrix. Filtering it with an all-ones kernel of the
    same shape is expected to produce an all-ones tensor.
    """
    shape = (3, 3, 3, 3, 1)

    pred = np.zeros(shape, dtype=np.float32)
    pred[:, :, 0, 0, 0] = np.identity(3, dtype=np.float32)
    pred = tf.convert_to_tensor(pred, dtype=tf.float32)

    kernel = tf.convert_to_tensor(np.ones(shape, dtype=np.float32), dtype=tf.float32)
    want = tf.convert_to_tensor(np.ones(shape, dtype=np.float32), dtype=tf.float32)

    result = label.separable_filter(pred, kernel)
    assert is_equal_tf(result, want)
53