1
|
|
|
from typing import List, Union |
2
|
|
|
|
3
|
|
|
import numpy as np |
4
|
|
|
import tensorflow as tf |
5
|
|
|
|
6
|
|
|
from deepreg.constant import EPS |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
def is_equal_np(
    x: Union[np.ndarray, List], y: Union[np.ndarray, List], atol: float = EPS
) -> bool:
    """
    Check if two numpy arrays are identical within a tolerance.

    Both inputs are first converted to float32 arrays. They are considered
    equal when they have the same shape, NaN values occur at exactly the
    same positions, and every other pair of elements satisfies numpy's
    closeness test ``|x - y| <= atol + rtol * |y|`` with numpy's default
    ``rtol=1e-05``. Infinities only match infinities of the same sign.

    :param x: first array, or a (nested) list convertible to a numpy array
    :param y: second array, or a (nested) list convertible to a numpy array
    :param atol: error margin, i.e. the absolute tolerance of the comparison
    :return: True if the two arrays are nearly equal, False otherwise
    """
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)

    # check shape; broadcasting between different shapes is not accepted
    if x.shape != y.shape:
        return False

    # check values, supporting the case where some values are nan:
    # equal_nan=True treats NaN == NaN at the same position, while NaN vs a
    # number is still a mismatch. Unlike replacing NaNs via np.nan_to_num,
    # this leaves +/-inf untouched, so inf is not conflated with the largest
    # finite float32. np.allclose already returns a Python bool.
    return np.allclose(x, y, atol=atol, equal_nan=True)
36
|
|
|
|
37
|
|
|
|
38
|
|
|
def is_equal_tf(
    x: Union[tf.Tensor, np.ndarray, List],
    y: Union[tf.Tensor, np.ndarray, List],
    atol: float = EPS,
) -> bool:
    """
    Check if two tensor-like values are identical within a tolerance.

    Each input is cast to float32 with TensorFlow and materialized as a numpy
    array via ``.numpy()`` (which presumably requires eager execution), and
    the comparison itself is delegated to ``is_equal_np``.

    :param x: first value; a tf tensor, numpy array or (nested) list
    :param y: second value; accepts the same types as ``x``
    :param atol: error margin forwarded unchanged to ``is_equal_np``
    :return: the result of ``is_equal_np`` on the two converted arrays
    """
    # convert x first, then y, exactly as two sequential casts would
    x_arr, y_arr = (tf.cast(value, dtype=tf.float32).numpy() for value in (x, y))
    return is_equal_np(x=x_arr, y=y_arr, atol=atol)
54
|
|
|
|