@@ 65-86 (lines=22) @@ | ||
62 | ) |
|
63 | assert is_equal_tf(got, expected) |
|
64 | ||
def test_2d(self):
    """Combine four corner values with per-axis weights and compare to a hand-computed sum."""
    # One floor weight per axis (0.2 for the first, 0.3 for the second);
    # the matching ceil weights are the complements 0.8 and 0.7.
    floor_w = tf.constant(np.array([[0.2], [0.3]], dtype=np.float32))
    ceil_w = 1 - floor_w

    # Corner values listed in the order (0, 0), (0, 1), (1, 0), (1, 1).
    corner_array = np.array([[1], [2], [3], [4]], dtype=np.float32)
    corner_vals = tf.constant(corner_array)

    # Each corner is weighted by the product of its per-axis weights:
    #   (0, 0): 1 * 0.2 * 0.3 = 0.06
    #   (0, 1): 2 * 0.2 * 0.7 = 0.28
    #   (1, 0): 3 * 0.8 * 0.3 = 0.72
    #   (1, 1): 4 * 0.8 * 0.7 = 2.24
    # which sums to 3.3.
    want = tf.constant(np.array([3.3], dtype=np.float32))

    result = layer_util.pyramid_combination(
        values=corner_vals,
        weight_floor=floor_w,
        weight_ceil=ceil_w,
    )
    assert is_equal_tf(result, want)
|
87 | ||
88 | def test_error_dim(self): |
|
89 | weights = tf.constant(np.array([[[0.2]], [[0.2]]], dtype=np.float32)) |
|
@@ 54-63 (lines=10) @@ | ||
51 | ||
52 | ||
53 | class TestPyramidCombination: |
|
def test_1d(self):
    """Combine two endpoint values along a single axis and compare to a hand-computed sum."""
    # Floor weight 0.2 for the only axis; the ceil weight is its complement 0.8.
    floor_w = tf.constant(np.array([[0.2]], dtype=np.float32))
    ceil_w = 1 - floor_w

    # Values at the floor and ceil ends of the axis, in that order.
    endpoint_vals = tf.constant(np.array([[1], [2]], dtype=np.float32))

    # expected = 1 * 0.2 + 2 * 0.8 = 1.8
    want = tf.constant(np.array([1.8], dtype=np.float32))

    result = layer_util.pyramid_combination(
        values=endpoint_vals,
        weight_floor=floor_w,
        weight_ceil=ceil_w,
    )
    assert is_equal_tf(result, want)
|
64 | ||
65 | def test_2d(self): |
|
66 | weights = tf.constant(np.array([[0.2], [0.3]], dtype=np.float32)) |