test.unit.util   A
last analyzed

Complexity

Total Complexity 4

Size/Duplication

Total Lines 54
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 4
eloc 23
dl 0
loc 54
rs 10
c 0
b 0
f 0

2 Functions

Rating   Name   Duplication   Size   Complexity  
A is_equal_tf() 0 16 1
A is_equal_np() 0 27 3
1
from typing import List, Union
2
3
import numpy as np
4
import tensorflow as tf
5
6
from deepreg.constant import EPS
7
8
9
def is_equal_np(
    x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = EPS
) -> bool:
    """
    Check if two numpy arrays are identical within a tolerance.

    Both inputs are converted to float32 arrays. They are considered equal
    when they have the same shape, NaNs appear at exactly the same
    positions, and every remaining pair of elements satisfies
    ``abs(x - y) <= atol + rtol * abs(y)`` with numpy's default
    ``rtol=1e-5``. Infinities only match infinities of the same sign.

    :param x: first array or (nested) list of numbers
    :param y: second array or (nested) list of numbers
    :param atol: absolute error margin
    :return: True if the two arrays are nearly equal, False otherwise
    """
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)

    # np.allclose broadcasts, so mismatched shapes must be rejected explicitly
    if x.shape != y.shape:
        return False

    # equal_nan=True requires NaNs to sit at the same positions in both arrays.
    # Comparing directly (rather than after nan_to_num) keeps +/-inf distinct
    # from the largest finite float32 values.
    return bool(np.allclose(x, y, atol=atol, equal_nan=True))
36
37
38
def is_equal_tf(
    x: Union[tf.Tensor, np.ndarray, List],
    y: Union[tf.Tensor, np.ndarray, List],
    atol: float = EPS,
) -> bool:
    """
    Compare two tensors (or array-likes) element-wise within a tolerance.

    Each input is cast to a float32 tf tensor, materialised as a numpy
    array via ``.numpy()``, and the pair is then compared by ``is_equal_np``.

    :param x: first tensor, numpy array or (nested) list
    :param y: second tensor, numpy array or (nested) list
    :param atol: absolute error margin forwarded to ``is_equal_np``
    :return: whether the two inputs are nearly equal, as decided by
        ``is_equal_np``
    """
    # x is converted before y, same order as a pair of explicit assignments
    x_arr, y_arr = (tf.cast(value, dtype=tf.float32).numpy() for value in (x, y))
    return is_equal_np(x=x_arr, y=y_arr, atol=atol)
54