|
@@ 560-581 (lines=22) @@
|
| 557 |
|
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) |
| 558 |
|
|
| 559 |
|
# TODO: distance problem |
| 560 |
|
def _test_search_distance_ip_after_index(self, connect, collection, get_simple_index):
    '''
    target: search collection after building an index with metric IP, and check the result: distance
    method: compare the first distance returned for the first query with the maximum
            inner product computed by brute force between that query vector and the
            first `nb` inserted vectors
    expected: the return distance equals the computed value (within gen_inaccuracy)
    '''
    index_type = get_simple_index["index_type"]
    nq = 2
    metric_type = "IP"
    entities, _ = init_data(connect, collection)
    # Build a copy instead of writing into the fixture's dict: parametrized fixture
    # values may be shared objects, so mutating it could leak "IP" into other tests.
    index_params = dict(get_simple_index, metric_type=metric_type)
    connect.create_index(collection, field_name, index_params)
    search_param = get_search_param(index_type)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metric_type,
                                    search_params=search_param)
    inside_vecs = entities[-1]["values"]
    # Brute-force the largest inner product for the first query vector. Seeding a
    # running maximum with 0 would be wrong whenever every inner product is negative.
    max_distance = max(ip(vecs[0], vec) for vec in inside_vecs[:nb])
    res = connect.search(collection, query)
    assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0])
| 583 |
|
|
| 584 |
|
# TODO: |
|
@@ 518-537 (lines=20) @@
|
| 515 |
|
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) |
| 516 |
|
|
| 517 |
|
# TODO: distance problem |
| 518 |
|
def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
    '''
    target: search collection after building an index, and check the result: distance
    method: compare the first distance returned for the first query with the minimum
            Euclidean (L2) distance computed by brute force between that query vector
            and the first `nb` inserted vectors
    expected: the return distance equals the computed value (within gen_inaccuracy)
    '''
    index_type = get_simple_index["index_type"]
    nq = 2
    entities, _ = init_data(connect, collection)
    connect.create_index(collection, field_name, get_simple_index)
    search_param = get_search_param(index_type)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
    inside_vecs = entities[-1]["values"]
    # Brute-force the nearest neighbour of the first query vector. Seeding a running
    # minimum with 1.0 silently capped the expected value whenever every distance exceeded 1.
    min_distance = min(l2(vecs[0], vec) for vec in inside_vecs[:nb])
    res = connect.search(collection, query)
    # The returned distance is square-rooted before comparing with l2(); presumably the
    # server reports squared L2 -- NOTE(review): confirm against l2()'s definition.
    assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])
| 538 |
|
|
| 539 |
|
# TODO |
| 540 |
|
@pytest.mark.level(2) |