@@ 64-85 (lines=22) @@ | ||
61 | ) |
|
62 | assert is_equal_tf(got, expected) |
|
63 | ||
def test_2d(self):
    """2-D case: four corner values combined with products of per-dimension weights."""
    # One floor weight per dimension; each ceil weight is its complement (1 - floor).
    floor_weights = tf.constant(
        np.array([0.2, 0.3], dtype=np.float32).reshape(2, 1)
    )
    ceil_weights = 1 - floor_weights
    # Corner values 1..4, one per row:
    #   corner (0, 0) -> 1, weight = 0.2 * 0.3
    #   corner (0, 1) -> 2, weight = 0.2 * 0.7
    #   corner (1, 0) -> 3, weight = 0.8 * 0.3
    #   corner (1, 1) -> 4, weight = 0.8 * 0.7
    corner_values = tf.constant(np.arange(1, 5, dtype=np.float32).reshape(4, 1))
    # 1 * 0.06 + 2 * 0.14 + 3 * 0.24 + 4 * 0.56
    #   = 0.06 + 0.28 + 0.72 + 2.24 = 3.3
    expected = tf.constant(np.array([3.3], dtype=np.float32))
    got = layer_util.pyramid_combination(
        values=corner_values, weight_floor=floor_weights, weight_ceil=ceil_weights
    )
    assert is_equal_tf(got, expected)
|
86 | ||
87 | def test_error_dim(self): |
|
88 | weights = tf.constant(np.array([[[0.2]], [[0.2]]], dtype=np.float32)) |
|
@@ 53-62 (lines=10) @@ | ||
50 | ||
51 | ||
52 | class TestPyramidCombination: |
|
def test_1d(self):
    """1-D case: two values combined with a floor weight and its complement."""
    weights = tf.constant(np.array([[0.2]], dtype=np.float32))
    values = tf.constant(np.array([[1], [2]], dtype=np.float32))

    # The first value takes weight_floor (0.2); the second takes
    # weight_ceil (1 - 0.2 = 0.8).
    # expected = 1 * 0.2 + 2 * 0.8 = 1.8
    expected = tf.constant(np.array([1.8], dtype=np.float32))
    got = layer_util.pyramid_combination(
        values=values, weight_floor=weights, weight_ceil=1 - weights
    )
    assert is_equal_tf(got, expected)
|
63 | ||
64 | def test_2d(self): |
|
65 | weights = tf.constant(np.array([[0.2], [0.3]], dtype=np.float32)) |