| Total Complexity | 201 |
| Total Lines | 1680 |
| Duplicated Lines | 19.11 % |
| Changes | 0 | ||
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems, and corresponding solutions are:
Complex classes like test_search often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | import time |
||
| 2 | import pdb |
||
| 3 | import copy |
||
| 4 | import threading |
||
| 5 | import logging |
||
| 6 | from multiprocessing import Pool, Process |
||
| 7 | import pytest |
||
| 8 | import numpy as np |
||
| 9 | |||
| 10 | from milvus import DataType |
||
| 11 | from utils import * |
||
| 12 | from constants import * |
||
| 13 | |||
# Prefix used to build unique collection names for this test module.
uid = "test_search"
# Default number of query vectors per search request.
nq = 1
# Tolerance for floating-point distance comparisons (L2 distances near 0).
epsilon = 0.001
# Names of the float / binary vector fields, taken from constants.
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
# Default search parameter for IVF-style indexes.
search_param = {"nprobe": 1}

# Pre-generated shared fixtures; tests reuse these instead of regenerating
# entities of the default size on every call.
entity = gen_entities(1, is_normal=True)
entities = gen_entities(default_nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(default_nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, default_top_k, nq)
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, default_top_k, nq)
| 26 | |||
def init_data(connect, collection, nb=default_nb, partition_tags=None, auto_id=True):
    '''
    Generate entities and insert them into `collection`.

    When `nb` equals `default_nb` the pre-generated module-level `entities`
    are reused instead of being regenerated.  The original code hard-coded
    1200 both as the default and in the reuse check, which silently drifts
    out of sync if `default_nb` changes; both now reference `default_nb`.

    :param connect: milvus connection fixture
    :param collection: name of the target collection
    :param nb: number of entities to insert
    :param partition_tags: partition tag to insert into, or None for the default partition
    :param auto_id: if False, explicit ids [0, nb) are supplied to the server
    :return: tuple (inserted entities, assigned ids)
    '''
    global entities
    if nb == default_nb:
        insert_entities = entities
    else:
        insert_entities = gen_entities(nb, is_normal=True)
    # Build the optional insert arguments instead of branching four ways.
    kwargs = {}
    if not auto_id:
        kwargs["ids"] = [i for i in range(nb)]
    if partition_tags is not None:
        kwargs["partition_tag"] = partition_tags
    ids = connect.insert(collection, insert_entities, **kwargs)
    connect.flush([collection])
    return insert_entities, ids
| 48 | |||
| 49 | |||
def init_binary_data(connect, collection, nb=default_nb, insert=True, partition_tags=None):
    '''
    Generate binary entities and optionally insert them into `collection`.

    When `nb` equals `default_nb` the pre-generated module-level
    `binary_entities`/`raw_vectors` are reused.  As in `init_data`, the
    hard-coded 1200 is replaced with `default_nb` so the reuse check cannot
    drift from the actual size of the shared fixtures.

    :param connect: milvus connection fixture
    :param collection: name of the target binary collection
    :param nb: number of entities to generate
    :param insert: if False, entities are generated but not inserted (ids stays [])
    :param partition_tags: partition tag to insert into, or None for the default partition
    :return: tuple (raw vectors, entities, assigned ids)
    '''
    ids = []
    global binary_entities
    global raw_vectors
    if nb == default_nb:
        insert_entities = binary_entities
        insert_raw_vectors = raw_vectors
    else:
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
    if insert is True:
        if partition_tags is None:
            ids = connect.insert(collection, insert_entities)
        else:
            ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
        connect.flush([collection])
    return insert_raw_vectors, insert_entities, ids
||
| 69 | |||
| 70 | |||
| 71 | class TestSearchBase: |
||
| 72 | """ |
||
| 73 | generate valid create_index params |
||
| 74 | """ |
||
| 75 | |||
| 76 | @pytest.fixture( |
||
| 77 | scope="function", |
||
| 78 | params=gen_index() |
||
| 79 | ) |
||
| 80 | def get_index(self, request, connect): |
||
| 81 | if str(connect._cmd("mode")) == "CPU": |
||
| 82 | if request.param["index_type"] in index_cpu_not_support(): |
||
| 83 | pytest.skip("sq8h not support in CPU mode") |
||
| 84 | return request.param |
||
| 85 | |||
| 86 | @pytest.fixture( |
||
| 87 | scope="function", |
||
| 88 | params=gen_simple_index() |
||
| 89 | ) |
||
| 90 | def get_simple_index(self, request, connect): |
||
| 91 | if str(connect._cmd("mode")) == "CPU": |
||
| 92 | if request.param["index_type"] in index_cpu_not_support(): |
||
| 93 | pytest.skip("sq8h not support in CPU mode") |
||
| 94 | return request.param |
||
| 95 | |||
| 96 | @pytest.fixture( |
||
| 97 | scope="function", |
||
| 98 | params=gen_binary_index() |
||
| 99 | ) |
||
| 100 | def get_jaccard_index(self, request, connect): |
||
| 101 | logging.getLogger().info(request.param) |
||
| 102 | if request.param["index_type"] in binary_support(): |
||
| 103 | return request.param |
||
| 104 | else: |
||
| 105 | pytest.skip("Skip index Temporary") |
||
| 106 | |||
| 107 | @pytest.fixture( |
||
| 108 | scope="function", |
||
| 109 | params=gen_binary_index() |
||
| 110 | ) |
||
| 111 | def get_hamming_index(self, request, connect): |
||
| 112 | logging.getLogger().info(request.param) |
||
| 113 | if request.param["index_type"] in binary_support(): |
||
| 114 | return request.param |
||
| 115 | else: |
||
| 116 | pytest.skip("Skip index Temporary") |
||
| 117 | |||
| 118 | @pytest.fixture( |
||
| 119 | scope="function", |
||
| 120 | params=gen_binary_index() |
||
| 121 | ) |
||
| 122 | def get_structure_index(self, request, connect): |
||
| 123 | logging.getLogger().info(request.param) |
||
| 124 | if request.param["index_type"] == "FLAT": |
||
| 125 | return request.param |
||
| 126 | else: |
||
| 127 | pytest.skip("Skip index Temporary") |
||
| 128 | |||
| 129 | """ |
||
| 130 | generate top-k params |
||
| 131 | """ |
||
| 132 | |||
| 133 | @pytest.fixture( |
||
| 134 | scope="function", |
||
| 135 | params=[1, 10] |
||
| 136 | ) |
||
| 137 | def get_top_k(self, request): |
||
| 138 | yield request.param |
||
| 139 | |||
| 140 | @pytest.fixture( |
||
| 141 | scope="function", |
||
| 142 | params=[1, 10, 1100] |
||
| 143 | ) |
||
| 144 | def get_nq(self, request): |
||
| 145 | yield request.param |
||
| 146 | |||
| 147 | View Code Duplication | def test_search_flat(self, connect, collection, get_top_k, get_nq): |
|
| 148 | ''' |
||
| 149 | target: test basic search function, all the search params is corrent, change top-k value |
||
| 150 | method: search with the given vectors, check the result |
||
| 151 | expected: the length of the result is top_k |
||
| 152 | ''' |
||
| 153 | top_k = get_top_k |
||
| 154 | nq = get_nq |
||
| 155 | entities, ids = init_data(connect, collection) |
||
| 156 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 157 | if top_k <= max_top_k: |
||
| 158 | res = connect.search(collection, query) |
||
| 159 | assert len(res[0]) == top_k |
||
| 160 | assert res[0]._distances[0] <= epsilon |
||
| 161 | assert check_id_result(res[0], ids[0]) |
||
| 162 | else: |
||
| 163 | with pytest.raises(Exception) as e: |
||
| 164 | res = connect.search(collection, query) |
||
| 165 | |||
| 166 | View Code Duplication | def test_search_flat_top_k(self, connect, collection, get_nq): |
|
| 167 | ''' |
||
| 168 | target: test basic search function, all the search params is corrent, change top-k value |
||
| 169 | method: search with the given vectors, check the result |
||
| 170 | expected: the length of the result is top_k |
||
| 171 | ''' |
||
| 172 | top_k = 16385 |
||
| 173 | nq = get_nq |
||
| 174 | entities, ids = init_data(connect, collection) |
||
| 175 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 176 | if top_k <= max_top_k: |
||
| 177 | res = connect.search(collection, query) |
||
| 178 | assert len(res[0]) == top_k |
||
| 179 | assert res[0]._distances[0] <= epsilon |
||
| 180 | assert check_id_result(res[0], ids[0]) |
||
| 181 | else: |
||
| 182 | with pytest.raises(Exception) as e: |
||
| 183 | res = connect.search(collection, query) |
||
| 184 | |||
| 185 | def test_search_field(self, connect, collection, get_top_k, get_nq): |
||
| 186 | ''' |
||
| 187 | target: test basic search function, all the search params is corrent, change top-k value |
||
| 188 | method: search with the given vectors, check the result |
||
| 189 | expected: the length of the result is top_k |
||
| 190 | ''' |
||
| 191 | top_k = get_top_k |
||
| 192 | nq = get_nq |
||
| 193 | entities, ids = init_data(connect, collection) |
||
| 194 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 195 | if top_k <= max_top_k: |
||
| 196 | res = connect.search(collection, query, fields=["float_vector"]) |
||
| 197 | assert len(res[0]) == top_k |
||
| 198 | assert res[0]._distances[0] <= epsilon |
||
| 199 | assert check_id_result(res[0], ids[0]) |
||
| 200 | res = connect.search(collection, query, fields=["float"]) |
||
| 201 | for i in range(nq): |
||
| 202 | assert entities[1]["values"][:nq][i] in [r.entity.get('float') for r in res[i]] |
||
| 203 | else: |
||
| 204 | with pytest.raises(Exception): |
||
| 205 | connect.search(collection, query) |
||
| 206 | |||
| 207 | def test_search_after_delete(self, connect, collection, get_top_k, get_nq): |
||
| 208 | ''' |
||
| 209 | target: test basic search function before and after deletion, all the search params is |
||
| 210 | corrent, change top-k value. |
||
| 211 | check issue <a href="https://github.com/milvus-io/milvus/issues/4200">#4200</a> |
||
| 212 | method: search with the given vectors, check the result |
||
| 213 | expected: the deleted entities do not exist in the result. |
||
| 214 | ''' |
||
| 215 | top_k = get_top_k |
||
| 216 | nq = get_nq |
||
| 217 | |||
| 218 | entities, ids = init_data(connect, collection, nb=10000) |
||
| 219 | first_int64_value = entities[0]["values"][0] |
||
| 220 | first_vector = entities[2]["values"][0] |
||
| 221 | |||
| 222 | search_param = get_search_param("FLAT") |
||
| 223 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 224 | vecs[:] = [] |
||
| 225 | vecs.append(first_vector) |
||
| 226 | |||
| 227 | res = None |
||
| 228 | if top_k > max_top_k: |
||
| 229 | with pytest.raises(Exception): |
||
| 230 | connect.search(collection, query, fields=['int64']) |
||
| 231 | pytest.skip("top_k value is larger than max_topp_k") |
||
| 232 | else: |
||
| 233 | res = connect.search(collection, query, fields=['int64']) |
||
| 234 | assert len(res) == 1 |
||
| 235 | assert len(res[0]) >= top_k |
||
| 236 | assert res[0][0].id == ids[0] |
||
| 237 | assert res[0][0].entity.get("int64") == first_int64_value |
||
| 238 | assert res[0]._distances[0] < epsilon |
||
| 239 | assert check_id_result(res[0], ids[0]) |
||
| 240 | |||
| 241 | connect.delete_entity_by_id(collection, ids[:1]) |
||
| 242 | connect.flush([collection]) |
||
| 243 | |||
| 244 | res2 = connect.search(collection, query, fields=['int64']) |
||
| 245 | assert len(res2) == 1 |
||
| 246 | assert len(res2[0]) >= top_k |
||
| 247 | assert res2[0][0].id != ids[0] |
||
| 248 | if top_k > 1: |
||
| 249 | assert res2[0][0].id == res[0][1].id |
||
| 250 | assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64") |
||
| 251 | |||
| 252 | # TODO: |
||
| 253 | @pytest.mark.level(2) |
||
| 254 | def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 255 | ''' |
||
| 256 | target: test basic search function, all the search params is corrent, test all index params, and build |
||
| 257 | method: search with the given vectors, check the result |
||
| 258 | expected: the length of the result is top_k |
||
| 259 | ''' |
||
| 260 | top_k = get_top_k |
||
| 261 | nq = get_nq |
||
| 262 | |||
| 263 | index_type = get_simple_index["index_type"] |
||
| 264 | if index_type in skip_pq(): |
||
| 265 | pytest.skip("Skip PQ") |
||
| 266 | entities, ids = init_data(connect, collection) |
||
| 267 | connect.create_index(collection, field_name, get_simple_index) |
||
| 268 | search_param = get_search_param(index_type) |
||
| 269 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 270 | if top_k > max_top_k: |
||
| 271 | with pytest.raises(Exception) as e: |
||
| 272 | res = connect.search(collection, query) |
||
| 273 | else: |
||
| 274 | res = connect.search(collection, query) |
||
| 275 | assert len(res) == nq |
||
| 276 | assert len(res[0]) >= top_k |
||
| 277 | assert res[0]._distances[0] < epsilon |
||
| 278 | assert check_id_result(res[0], ids[0]) |
||
| 279 | |||
| 280 | def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index): |
||
| 281 | ''' |
||
| 282 | target: test search with different metric_type |
||
| 283 | method: build index with L2, and search using IP |
||
| 284 | expected: search ok |
||
| 285 | ''' |
||
| 286 | search_metric_type = "IP" |
||
| 287 | index_type = get_simple_index["index_type"] |
||
| 288 | entities, ids = init_data(connect, collection) |
||
| 289 | connect.create_index(collection, field_name, get_simple_index) |
||
| 290 | search_param = get_search_param(index_type) |
||
| 291 | query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, metric_type=search_metric_type, |
||
| 292 | search_params=search_param) |
||
| 293 | res = connect.search(collection, query) |
||
| 294 | assert len(res) == nq |
||
| 295 | assert len(res[0]) == default_top_k |
||
| 296 | |||
| 297 | @pytest.mark.level(2) |
||
| 298 | def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 299 | ''' |
||
| 300 | target: test basic search function, all the search params is corrent, test all index params, and build |
||
| 301 | method: add vectors into collection, search with the given vectors, check the result |
||
| 302 | expected: the length of the result is top_k, search collection with partition tag return empty |
||
| 303 | ''' |
||
| 304 | top_k = get_top_k |
||
| 305 | nq = get_nq |
||
| 306 | |||
| 307 | index_type = get_simple_index["index_type"] |
||
| 308 | if index_type in skip_pq(): |
||
| 309 | pytest.skip("Skip PQ") |
||
| 310 | connect.create_partition(collection, default_tag) |
||
| 311 | entities, ids = init_data(connect, collection) |
||
| 312 | connect.create_index(collection, field_name, get_simple_index) |
||
| 313 | search_param = get_search_param(index_type) |
||
| 314 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 315 | if top_k > max_top_k: |
||
| 316 | with pytest.raises(Exception) as e: |
||
| 317 | res = connect.search(collection, query) |
||
| 318 | else: |
||
| 319 | res = connect.search(collection, query) |
||
| 320 | assert len(res) == nq |
||
| 321 | assert len(res[0]) >= top_k |
||
| 322 | assert res[0]._distances[0] < epsilon |
||
| 323 | assert check_id_result(res[0], ids[0]) |
||
| 324 | res = connect.search(collection, query, partition_tags=[default_tag]) |
||
| 325 | assert len(res) == nq |
||
| 326 | |||
| 327 | @pytest.mark.level(2) |
||
| 328 | def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 329 | ''' |
||
| 330 | target: test basic search function, all the search params is corrent, test all index params, and build |
||
| 331 | method: search with the given vectors, check the result |
||
| 332 | expected: the length of the result is top_k |
||
| 333 | ''' |
||
| 334 | top_k = get_top_k |
||
| 335 | nq = get_nq |
||
| 336 | |||
| 337 | index_type = get_simple_index["index_type"] |
||
| 338 | if index_type in skip_pq(): |
||
| 339 | pytest.skip("Skip PQ") |
||
| 340 | connect.create_partition(collection, default_tag) |
||
| 341 | entities, ids = init_data(connect, collection, partition_tags=default_tag) |
||
| 342 | connect.create_index(collection, field_name, get_simple_index) |
||
| 343 | search_param = get_search_param(index_type) |
||
| 344 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 345 | for tags in [[default_tag], [default_tag, "new_tag"]]: |
||
| 346 | if top_k > max_top_k: |
||
| 347 | with pytest.raises(Exception) as e: |
||
| 348 | res = connect.search(collection, query, partition_tags=tags) |
||
| 349 | else: |
||
| 350 | res = connect.search(collection, query, partition_tags=tags) |
||
| 351 | assert len(res) == nq |
||
| 352 | assert len(res[0]) >= top_k |
||
| 353 | assert res[0]._distances[0] < epsilon |
||
| 354 | assert check_id_result(res[0], ids[0]) |
||
| 355 | |||
| 356 | @pytest.mark.level(2) |
||
| 357 | def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq): |
||
| 358 | ''' |
||
| 359 | target: test basic search function, all the search params is corrent, test all index params, and build |
||
| 360 | method: search with the given vectors and tag (tag name not existed in collection), check the result |
||
| 361 | expected: error raised |
||
| 362 | ''' |
||
| 363 | top_k = get_top_k |
||
| 364 | nq = get_nq |
||
| 365 | entities, ids = init_data(connect, collection) |
||
| 366 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 367 | if top_k > max_top_k: |
||
| 368 | with pytest.raises(Exception) as e: |
||
| 369 | res = connect.search(collection, query, partition_tags=["new_tag"]) |
||
| 370 | else: |
||
| 371 | res = connect.search(collection, query, partition_tags=["new_tag"]) |
||
| 372 | assert len(res) == nq |
||
| 373 | assert len(res[0]) == 0 |
||
| 374 | |||
    @pytest.mark.level(2)
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: search a collection holding two populated partitions, after indexing
        method: insert into default_tag and new_tag partitions, build index,
                search the whole collection and then only the new_tag partition
        expected: query vectors (drawn from the default_tag data) match ids from
                  that partition with near-zero distance; restricted to new_tag
                  the best distances are clearly non-zero
        '''
        top_k = get_top_k
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=default_tag)
        # second partition gets a different entity count so ids cannot collide
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > max_top_k:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            # exact matches must come from the default_tag partition only
            assert check_id_result(res[0], ids[0])
            assert not check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] < epsilon
            assert res[1]._distances[0] < epsilon
            # restricted to the other partition there is no exact match,
            # so the best distance must exceed the exact-match tolerance
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert res[0]._distances[0] > epsilon
            assert res[1]._distances[0] > epsilon
||
| 407 | |||
    @pytest.mark.level(2)
    def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: search with regex-style partition tag patterns, after indexing
        method: insert into "tag" and "new_tag" partitions, build index, search
                with tag patterns "(.*)tag" (matches both) and "new(.*)"
                (matches only new_tag); query vectors come from new_entities
        expected: exact matches found in the new_tag partition under both
                  patterns; no exact match against the first partition's ids
        '''
        top_k = get_top_k
        nq = 2
        tag = "tag"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        # NOTE: queries are built from new_entities, so exact matches live in new_tag
        query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param)
        if top_k > max_top_k:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            # "(.*)tag" matches both partitions; exact hits still come from new_tag
            res = connect.search(collection, query, partition_tags=["(.*)tag"])
            assert not check_id_result(res[0], ids[0])
            assert res[0]._distances[0] < epsilon
            assert res[1]._distances[0] < epsilon
            # "new(.*)" restricts to new_tag only; exact hits remain
            res = connect.search(collection, query, partition_tags=["new(.*)"])
            assert res[0]._distances[0] < epsilon
            assert res[1]._distances[0] < epsilon
||
| 440 | |||
| 441 | # |
||
| 442 | # test for ip metric |
||
| 443 | # |
||
| 444 | @pytest.mark.level(2) |
||
| 445 | def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 446 | ''' |
||
| 447 | target: test basic search function, all the search params is corrent, change top-k value |
||
| 448 | method: search with the given vectors, check the result |
||
| 449 | expected: the length of the result is top_k |
||
| 450 | ''' |
||
| 451 | top_k = get_top_k |
||
| 452 | nq = get_nq |
||
| 453 | entities, ids = init_data(connect, collection) |
||
| 454 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP") |
||
| 455 | if top_k <= max_top_k: |
||
| 456 | res = connect.search(collection, query) |
||
| 457 | assert len(res[0]) == top_k |
||
| 458 | assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) |
||
| 459 | assert check_id_result(res[0], ids[0]) |
||
| 460 | else: |
||
| 461 | with pytest.raises(Exception) as e: |
||
| 462 | res = connect.search(collection, query) |
||
| 463 | |||
| 464 | @pytest.mark.level(2) |
||
| 465 | def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 466 | ''' |
||
| 467 | target: test basic search function, all the search params is corrent, test all index params, and build |
||
| 468 | method: search with the given vectors, check the result |
||
| 469 | expected: the length of the result is top_k |
||
| 470 | ''' |
||
| 471 | top_k = get_top_k |
||
| 472 | nq = get_nq |
||
| 473 | |||
| 474 | index_type = get_simple_index["index_type"] |
||
| 475 | if index_type in skip_pq(): |
||
| 476 | pytest.skip("Skip PQ") |
||
| 477 | entities, ids = init_data(connect, collection) |
||
| 478 | get_simple_index["metric_type"] = "IP" |
||
| 479 | connect.create_index(collection, field_name, get_simple_index) |
||
| 480 | search_param = get_search_param(index_type) |
||
| 481 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param) |
||
| 482 | if top_k > max_top_k: |
||
| 483 | with pytest.raises(Exception) as e: |
||
| 484 | res = connect.search(collection, query) |
||
| 485 | else: |
||
| 486 | res = connect.search(collection, query) |
||
| 487 | assert len(res) == nq |
||
| 488 | assert len(res[0]) >= top_k |
||
| 489 | assert check_id_result(res[0], ids[0]) |
||
| 490 | assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) |
||
| 491 | |||
    @pytest.mark.level(2)
    def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: IP-metric search on a collection with an (empty) partition
        method: create partition, insert into the default partition, build an IP
                index, search the whole collection and then only the empty partition
        expected: full search returns at least top_k hits with a near-1 best IP
                  score; partition-scoped search still returns nq result lists
        '''
        top_k = get_top_k
        nq = get_nq
        metric_type = "IP"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        entities, ids = init_data(connect, collection)
        get_simple_index["metric_type"] = metric_type
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
                                        search_params=search_param)
        if top_k > max_top_k:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            # self-similarity under IP should be ~1
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert check_id_result(res[0], ids[0])
            # entities live in the default partition; default_tag is empty
            res = connect.search(collection, query, partition_tags=[default_tag])
            assert len(res) == nq
||
| 523 | |||
    @pytest.mark.level(2)
    def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: IP-metric search on a collection holding two populated partitions
        method: insert into default_tag and new_tag partitions, build an IP
                index, search the whole collection and then only new_tag
        expected: query vectors (from the default_tag data) hit that partition's
                  ids with a near-1 IP score; restricted to new_tag the best
                  score is clearly below 1
        '''
        top_k = get_top_k
        nq = 2
        metric_type = "IP"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=default_tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        get_simple_index["metric_type"] = metric_type
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
        if top_k > max_top_k:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            # exact matches must come from the default_tag partition only
            assert check_id_result(res[0], ids[0])
            assert not check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
            # restricted to new_tag there is no exact match; score falls below 1
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
        # TODO:
        # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
||
| 559 | |||
| 560 | @pytest.mark.level(2) |
||
| 561 | def test_search_without_connect(self, dis_connect, collection): |
||
| 562 | ''' |
||
| 563 | target: test search vectors without connection |
||
| 564 | method: use dis connected instance, call search method and check if search successfully |
||
| 565 | expected: raise exception |
||
| 566 | ''' |
||
| 567 | with pytest.raises(Exception) as e: |
||
| 568 | res = dis_connect.search(collection, default_query) |
||
| 569 | |||
| 570 | def test_search_collection_name_not_existed(self, connect): |
||
| 571 | ''' |
||
| 572 | target: search collection not existed |
||
| 573 | method: search with the random collection_name, which is not in db |
||
| 574 | expected: status not ok |
||
| 575 | ''' |
||
| 576 | collection_name = gen_unique_str(uid) |
||
| 577 | with pytest.raises(Exception) as e: |
||
| 578 | res = connect.search(collection_name, default_query) |
||
| 579 | |||
| 580 | def test_search_distance_l2(self, connect, collection): |
||
| 581 | ''' |
||
| 582 | target: search collection, and check the result: distance |
||
| 583 | method: compare the return distance value with value computed with Euclidean |
||
| 584 | expected: the return distance equals to the computed value |
||
| 585 | ''' |
||
| 586 | nq = 2 |
||
| 587 | search_param = {"nprobe": 1} |
||
| 588 | entities, ids = init_data(connect, collection, nb=nq) |
||
| 589 | query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param) |
||
| 590 | inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param) |
||
| 591 | distance_0 = l2(vecs[0], inside_vecs[0]) |
||
| 592 | distance_1 = l2(vecs[0], inside_vecs[1]) |
||
| 593 | res = connect.search(collection, query) |
||
| 594 | assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) |
||
| 595 | |||
    def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
        '''
        target: verify L2 search results after indexing against a brute-force scan
        method: insert with explicit ids, build index, search with random query
                vectors, find the true nearest neighbour locally and compare ids
        expected: the server's top hit id matches the locally computed minimum
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, id_collection, auto_id=False)
        connect.create_index(id_collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # NOTE(review): min_distance starts at 1.0 — if every computed distance
        # exceeds 1.0, min_id stays None and check_id_result is called with
        # None; confirm whether float('inf') was intended here.
        min_distance = 1.0
        min_id = None
        # brute-force scan for the true nearest neighbour of the first query
        for i in range(default_nb):
            tmp_dis = l2(vecs[0], inside_vecs[i])
            if min_distance > tmp_dis:
                min_distance = tmp_dis
                min_id = ids[i]
        res = connect.search(id_collection, query)
        tmp_epsilon = epsilon
        # NOTE(review): the return value of check_id_result is discarded —
        # this line asserts nothing; presumably an `assert` is missing.
        check_id_result(res[0], min_id)
        # if index_type in ["ANNOY", "IVF_PQ"]:
        #     tmp_epsilon = 0.1
        # TODO:
        # assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
||
| 623 | |||
| 624 | @pytest.mark.level(2) |
||
| 625 | def test_search_distance_ip(self, connect, collection): |
||
| 626 | ''' |
||
| 627 | target: search collection, and check the result: distance |
||
| 628 | method: compare the return distance value with value computed with Inner product |
||
| 629 | expected: the return distance equals to the computed value |
||
| 630 | ''' |
||
| 631 | nq = 2 |
||
| 632 | metirc_type = "IP" |
||
| 633 | search_param = {"nprobe": 1} |
||
| 634 | entities, ids = init_data(connect, collection, nb=nq) |
||
| 635 | query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metirc_type, |
||
| 636 | search_params=search_param) |
||
| 637 | inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param) |
||
| 638 | distance_0 = ip(vecs[0], inside_vecs[0]) |
||
| 639 | distance_1 = ip(vecs[0], inside_vecs[1]) |
||
| 640 | res = connect.search(collection, query) |
||
| 641 | assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon |
||
| 642 | |||
| 643 | View Code Duplication | def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index): |
|
| 644 | ''' |
||
| 645 | target: search collection, and check the result: distance |
||
| 646 | method: compare the return distance value with value computed with Inner product |
||
| 647 | expected: the return distance equals to the computed value |
||
| 648 | ''' |
||
| 649 | index_type = get_simple_index["index_type"] |
||
| 650 | nq = 2 |
||
| 651 | metirc_type = "IP" |
||
| 652 | entities, ids = init_data(connect, id_collection, auto_id=False) |
||
| 653 | get_simple_index["metric_type"] = metirc_type |
||
| 654 | connect.create_index(id_collection, field_name, get_simple_index) |
||
| 655 | search_param = get_search_param(index_type) |
||
| 656 | query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metirc_type, |
||
| 657 | search_params=search_param) |
||
| 658 | inside_vecs = entities[-1]["values"] |
||
| 659 | max_distance = 0 |
||
| 660 | max_id = None |
||
| 661 | for i in range(default_nb): |
||
| 662 | tmp_dis = ip(vecs[0], inside_vecs[i]) |
||
| 663 | if max_distance < tmp_dis: |
||
| 664 | max_distance = tmp_dis |
||
| 665 | max_id = ids[i] |
||
| 666 | res = connect.search(id_collection, query) |
||
| 667 | tmp_epsilon = epsilon |
||
| 668 | check_id_result(res[0], max_id) |
||
| 669 | # if index_type in ["ANNOY", "IVF_PQ"]: |
||
| 670 | # tmp_epsilon = 0.1 |
||
| 671 | # TODO: |
||
| 672 | # assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon |
||
| 673 | |||
| 674 | View Code Duplication | def test_search_distance_jaccard_flat_index(self, connect, binary_collection): |
|
| 675 | ''' |
||
| 676 | target: search binary_collection, and check the result: distance |
||
| 677 | method: compare the return distance value with value computed with L2 |
||
| 678 | expected: the return distance equals to the computed value |
||
| 679 | ''' |
||
| 680 | nq = 1 |
||
| 681 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 682 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 683 | distance_0 = jaccard(query_int_vectors[0], int_vectors[0]) |
||
| 684 | distance_1 = jaccard(query_int_vectors[0], int_vectors[1]) |
||
| 685 | query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="JACCARD") |
||
| 686 | res = connect.search(binary_collection, query) |
||
| 687 | assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon |
||
| 688 | |||
| 689 | @pytest.mark.level(2) |
||
| 690 | def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection): |
||
| 691 | ''' |
||
| 692 | target: search binary_collection, and check the result: distance |
||
| 693 | method: compare the return distance value with value computed with L2 |
||
| 694 | expected: the return distance equals to the computed value |
||
| 695 | ''' |
||
| 696 | nq = 1 |
||
| 697 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 698 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 699 | distance_0 = jaccard(query_int_vectors[0], int_vectors[0]) |
||
| 700 | distance_1 = jaccard(query_int_vectors[0], int_vectors[1]) |
||
| 701 | query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="L2") |
||
| 702 | with pytest.raises(Exception) as e: |
||
| 703 | res = connect.search(binary_collection, query) |
||
| 704 | |||
| 705 | View Code Duplication | @pytest.mark.level(2) |
|
| 706 | def test_search_distance_hamming_flat_index(self, connect, binary_collection): |
||
| 707 | ''' |
||
| 708 | target: search binary_collection, and check the result: distance |
||
| 709 | method: compare the return distance value with value computed with Inner product |
||
| 710 | expected: the return distance equals to the computed value |
||
| 711 | ''' |
||
| 712 | nq = 1 |
||
| 713 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 714 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 715 | distance_0 = hamming(query_int_vectors[0], int_vectors[0]) |
||
| 716 | distance_1 = hamming(query_int_vectors[0], int_vectors[1]) |
||
| 717 | query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="HAMMING") |
||
| 718 | res = connect.search(binary_collection, query) |
||
| 719 | assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon |
||
| 720 | |||
| 721 | View Code Duplication | @pytest.mark.level(2) |
|
| 722 | def test_search_distance_substructure_flat_index(self, connect, binary_collection): |
||
| 723 | ''' |
||
| 724 | target: search binary_collection, and check the result: distance |
||
| 725 | method: compare the return distance value with value computed with Inner product |
||
| 726 | expected: the return distance equals to the computed value |
||
| 727 | ''' |
||
| 728 | nq = 1 |
||
| 729 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 730 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 731 | distance_0 = substructure(query_int_vectors[0], int_vectors[0]) |
||
| 732 | distance_1 = substructure(query_int_vectors[0], int_vectors[1]) |
||
| 733 | query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUBSTRUCTURE") |
||
| 734 | res = connect.search(binary_collection, query) |
||
| 735 | assert len(res[0]) == 0 |
||
| 736 | |||
| 737 | @pytest.mark.level(2) |
||
| 738 | def test_search_distance_substructure_flat_index_B(self, connect, binary_collection): |
||
| 739 | ''' |
||
| 740 | target: search binary_collection, and check the result: distance |
||
| 741 | method: compare the return distance value with value computed with SUB |
||
| 742 | expected: the return distance equals to the computed value |
||
| 743 | ''' |
||
| 744 | top_k = 3 |
||
| 745 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 746 | query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2) |
||
| 747 | query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUBSTRUCTURE", replace_vecs=query_vecs) |
||
| 748 | res = connect.search(binary_collection, query) |
||
| 749 | assert res[0][0].distance <= epsilon |
||
| 750 | assert res[0][0].id == ids[0] |
||
| 751 | assert res[1][0].distance <= epsilon |
||
| 752 | assert res[1][0].id == ids[1] |
||
| 753 | |||
| 754 | View Code Duplication | @pytest.mark.level(2) |
|
| 755 | def test_search_distance_superstructure_flat_index(self, connect, binary_collection): |
||
| 756 | ''' |
||
| 757 | target: search binary_collection, and check the result: distance |
||
| 758 | method: compare the return distance value with value computed with Inner product |
||
| 759 | expected: the return distance equals to the computed value |
||
| 760 | ''' |
||
| 761 | nq = 1 |
||
| 762 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 763 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 764 | distance_0 = superstructure(query_int_vectors[0], int_vectors[0]) |
||
| 765 | distance_1 = superstructure(query_int_vectors[0], int_vectors[1]) |
||
| 766 | query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUPERSTRUCTURE") |
||
| 767 | res = connect.search(binary_collection, query) |
||
| 768 | assert len(res[0]) == 0 |
||
| 769 | |||
| 770 | @pytest.mark.level(2) |
||
| 771 | def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection): |
||
| 772 | ''' |
||
| 773 | target: search binary_collection, and check the result: distance |
||
| 774 | method: compare the return distance value with value computed with SUPER |
||
| 775 | expected: the return distance equals to the computed value |
||
| 776 | ''' |
||
| 777 | top_k = 3 |
||
| 778 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 779 | query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2) |
||
| 780 | query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUPERSTRUCTURE", replace_vecs=query_vecs) |
||
| 781 | res = connect.search(binary_collection, query) |
||
| 782 | assert len(res[0]) == 2 |
||
| 783 | assert len(res[1]) == 2 |
||
| 784 | assert res[0][0].id in ids |
||
| 785 | assert res[0][0].distance <= epsilon |
||
| 786 | assert res[1][0].id in ids |
||
| 787 | assert res[1][0].distance <= epsilon |
||
| 788 | |||
| 789 | View Code Duplication | @pytest.mark.level(2) |
|
| 790 | def test_search_distance_tanimoto_flat_index(self, connect, binary_collection): |
||
| 791 | ''' |
||
| 792 | target: search binary_collection, and check the result: distance |
||
| 793 | method: compare the return distance value with value computed with Inner product |
||
| 794 | expected: the return distance equals to the computed value |
||
| 795 | ''' |
||
| 796 | nq = 1 |
||
| 797 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 798 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 799 | distance_0 = tanimoto(query_int_vectors[0], int_vectors[0]) |
||
| 800 | distance_1 = tanimoto(query_int_vectors[0], int_vectors[1]) |
||
| 801 | query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="TANIMOTO") |
||
| 802 | res = connect.search(binary_collection, query) |
||
| 803 | assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon |
||
| 804 | |||
| 805 | @pytest.mark.level(2) |
||
| 806 | @pytest.mark.timeout(30) |
||
| 807 | def test_search_concurrent_multithreads(self, connect, args): |
||
| 808 | ''' |
||
| 809 | target: test concurrent search with multiprocessess |
||
| 810 | method: search with 10 processes, each process uses dependent connection |
||
| 811 | expected: status ok and the returned vectors should be query_records |
||
| 812 | ''' |
||
| 813 | nb = 100 |
||
| 814 | top_k = 10 |
||
| 815 | threads_num = 4 |
||
| 816 | threads = [] |
||
| 817 | collection = gen_unique_str(uid) |
||
| 818 | uri = "tcp://%s:%s" % (args["ip"], args["port"]) |
||
| 819 | # create collection |
||
| 820 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
| 821 | milvus.create_collection(collection, default_fields) |
||
| 822 | entities, ids = init_data(milvus, collection) |
||
| 823 | |||
| 824 | def search(milvus): |
||
| 825 | res = connect.search(collection, default_query) |
||
| 826 | assert len(res) == 1 |
||
| 827 | assert res[0]._entities[0].id in ids |
||
| 828 | assert res[0]._distances[0] < epsilon |
||
| 829 | |||
| 830 | for i in range(threads_num): |
||
| 831 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
| 832 | t = threading.Thread(target=search, args=(milvus,)) |
||
| 833 | threads.append(t) |
||
| 834 | t.start() |
||
| 835 | time.sleep(0.2) |
||
| 836 | for t in threads: |
||
| 837 | t.join() |
||
| 838 | |||
| 839 | @pytest.mark.level(2) |
||
| 840 | @pytest.mark.timeout(30) |
||
| 841 | def test_search_concurrent_multithreads_single_connection(self, connect, args): |
||
| 842 | ''' |
||
| 843 | target: test concurrent search with multiprocessess |
||
| 844 | method: search with 10 processes, each process uses dependent connection |
||
| 845 | expected: status ok and the returned vectors should be query_records |
||
| 846 | ''' |
||
| 847 | nb = 100 |
||
| 848 | top_k = 10 |
||
| 849 | threads_num = 4 |
||
| 850 | threads = [] |
||
| 851 | collection = gen_unique_str(uid) |
||
| 852 | uri = "tcp://%s:%s" % (args["ip"], args["port"]) |
||
| 853 | # create collection |
||
| 854 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
| 855 | milvus.create_collection(collection, default_fields) |
||
| 856 | entities, ids = init_data(milvus, collection) |
||
| 857 | |||
| 858 | def search(milvus): |
||
| 859 | res = connect.search(collection, default_query) |
||
| 860 | assert len(res) == 1 |
||
| 861 | assert res[0]._entities[0].id in ids |
||
| 862 | assert res[0]._distances[0] < epsilon |
||
| 863 | |||
| 864 | for i in range(threads_num): |
||
| 865 | t = threading.Thread(target=search, args=(milvus,)) |
||
| 866 | threads.append(t) |
||
| 867 | t.start() |
||
| 868 | time.sleep(0.2) |
||
| 869 | for t in threads: |
||
| 870 | t.join() |
||
| 871 | |||
| 872 | @pytest.mark.level(2) |
||
| 873 | def test_search_multi_collections(self, connect, args): |
||
| 874 | ''' |
||
| 875 | target: test search multi collections of L2 |
||
| 876 | method: add vectors into 10 collections, and search |
||
| 877 | expected: search status ok, the length of result |
||
| 878 | ''' |
||
| 879 | num = 10 |
||
| 880 | top_k = 10 |
||
| 881 | nq = 20 |
||
| 882 | for i in range(num): |
||
| 883 | collection = gen_unique_str(uid + str(i)) |
||
| 884 | connect.create_collection(collection, default_fields) |
||
| 885 | entities, ids = init_data(connect, collection) |
||
| 886 | assert len(ids) == default_nb |
||
| 887 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 888 | res = connect.search(collection, query) |
||
| 889 | assert len(res) == nq |
||
| 890 | for i in range(nq): |
||
| 891 | assert check_id_result(res[i], ids[i]) |
||
| 892 | assert res[i]._distances[0] < epsilon |
||
| 893 | assert res[i]._distances[1] > epsilon |
||
| 894 | |||
| 895 | |||
| 896 | class TestSearchDSL(object): |
||
| 897 | """ |
||
| 898 | ****************************************************************** |
||
| 899 | # The following cases are used to build invalid query expr |
||
| 900 | ****************************************************************** |
||
| 901 | """ |
||
| 902 | |||
| 903 | def test_query_no_must(self, connect, collection): |
||
| 904 | ''' |
||
| 905 | method: build query without must expr |
||
| 906 | expected: error raised |
||
| 907 | ''' |
||
| 908 | # entities, ids = init_data(connect, collection) |
||
| 909 | query = update_query_expr(default_query, keep_old=False) |
||
| 910 | with pytest.raises(Exception) as e: |
||
| 911 | res = connect.search(collection, query) |
||
| 912 | |||
| 913 | def test_query_no_vector_term_only(self, connect, collection): |
||
| 914 | ''' |
||
| 915 | method: build query without vector only term |
||
| 916 | expected: error raised |
||
| 917 | ''' |
||
| 918 | # entities, ids = init_data(connect, collection) |
||
| 919 | expr = { |
||
| 920 | "must": [gen_default_term_expr] |
||
| 921 | } |
||
| 922 | query = update_query_expr(default_query, keep_old=False, expr=expr) |
||
| 923 | with pytest.raises(Exception) as e: |
||
| 924 | res = connect.search(collection, query) |
||
| 925 | |||
| 926 | def test_query_no_vector_range_only(self, connect, collection): |
||
| 927 | ''' |
||
| 928 | method: build query without vector only range |
||
| 929 | expected: error raised |
||
| 930 | ''' |
||
| 931 | # entities, ids = init_data(connect, collection) |
||
| 932 | expr = { |
||
| 933 | "must": [gen_default_range_expr] |
||
| 934 | } |
||
| 935 | query = update_query_expr(default_query, keep_old=False, expr=expr) |
||
| 936 | with pytest.raises(Exception) as e: |
||
| 937 | res = connect.search(collection, query) |
||
| 938 | |||
| 939 | def test_query_vector_only(self, connect, collection): |
||
| 940 | entities, ids = init_data(connect, collection) |
||
| 941 | res = connect.search(collection, default_query) |
||
| 942 | assert len(res) == nq |
||
| 943 | assert len(res[0]) == default_top_k |
||
| 944 | |||
| 945 | def test_query_wrong_format(self, connect, collection): |
||
| 946 | ''' |
||
| 947 | method: build query without must expr, with wrong expr name |
||
| 948 | expected: error raised |
||
| 949 | ''' |
||
| 950 | # entities, ids = init_data(connect, collection) |
||
| 951 | expr = { |
||
| 952 | "must1": [gen_default_term_expr] |
||
| 953 | } |
||
| 954 | query = update_query_expr(default_query, keep_old=False, expr=expr) |
||
| 955 | with pytest.raises(Exception) as e: |
||
| 956 | res = connect.search(collection, query) |
||
| 957 | |||
| 958 | def test_query_empty(self, connect, collection): |
||
| 959 | ''' |
||
| 960 | method: search with empty query |
||
| 961 | expected: error raised |
||
| 962 | ''' |
||
| 963 | query = {} |
||
| 964 | with pytest.raises(Exception) as e: |
||
| 965 | res = connect.search(collection, query) |
||
| 966 | |||
| 967 | """ |
||
| 968 | ****************************************************************** |
||
| 969 | # The following cases are used to build valid query expr |
||
| 970 | ****************************************************************** |
||
| 971 | """ |
||
| 972 | |||
| 973 | @pytest.mark.level(2) |
||
| 974 | def test_query_term_value_not_in(self, connect, collection): |
||
| 975 | ''' |
||
| 976 | method: build query with vector and term expr, with no term can be filtered |
||
| 977 | expected: filter pass |
||
| 978 | ''' |
||
| 979 | entities, ids = init_data(connect, collection) |
||
| 980 | expr = { |
||
| 981 | "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]} |
||
| 982 | query = update_query_expr(default_query, expr=expr) |
||
| 983 | res = connect.search(collection, query) |
||
| 984 | assert len(res) == nq |
||
| 985 | assert len(res[0]) == 0 |
||
| 986 | # TODO: |
||
| 987 | |||
| 988 | # TODO: |
||
| 989 | @pytest.mark.level(2) |
||
| 990 | def test_query_term_value_all_in(self, connect, collection): |
||
| 991 | ''' |
||
| 992 | method: build query with vector and term expr, with all term can be filtered |
||
| 993 | expected: filter pass |
||
| 994 | ''' |
||
| 995 | entities, ids = init_data(connect, collection) |
||
| 996 | expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]} |
||
| 997 | query = update_query_expr(default_query, expr=expr) |
||
| 998 | res = connect.search(collection, query) |
||
| 999 | assert len(res) == nq |
||
| 1000 | assert len(res[0]) == 1 |
||
| 1001 | # TODO: |
||
| 1002 | |||
| 1003 | # TODO: |
||
| 1004 | View Code Duplication | @pytest.mark.level(2) |
|
| 1005 | def test_query_term_values_not_in(self, connect, collection): |
||
| 1006 | ''' |
||
| 1007 | method: build query with vector and term expr, with no term can be filtered |
||
| 1008 | expected: filter pass |
||
| 1009 | ''' |
||
| 1010 | entities, ids = init_data(connect, collection) |
||
| 1011 | expr = {"must": [gen_default_vector_expr(default_query), |
||
| 1012 | gen_default_term_expr(values=[i for i in range(100000, 100010)])]} |
||
| 1013 | query = update_query_expr(default_query, expr=expr) |
||
| 1014 | res = connect.search(collection, query) |
||
| 1015 | assert len(res) == nq |
||
| 1016 | assert len(res[0]) == 0 |
||
| 1017 | # TODO: |
||
| 1018 | |||
| 1019 | def test_query_term_values_all_in(self, connect, collection): |
||
| 1020 | ''' |
||
| 1021 | method: build query with vector and term expr, with all term can be filtered |
||
| 1022 | expected: filter pass |
||
| 1023 | ''' |
||
| 1024 | entities, ids = init_data(connect, collection) |
||
| 1025 | expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]} |
||
| 1026 | query = update_query_expr(default_query, expr=expr) |
||
| 1027 | res = connect.search(collection, query) |
||
| 1028 | assert len(res) == nq |
||
| 1029 | assert len(res[0]) == default_top_k |
||
| 1030 | # TODO: |
||
| 1031 | |||
| 1032 | View Code Duplication | def test_query_term_values_parts_in(self, connect, collection): |
|
| 1033 | ''' |
||
| 1034 | method: build query with vector and term expr, with parts of term can be filtered |
||
| 1035 | expected: filter pass |
||
| 1036 | ''' |
||
| 1037 | entities, ids = init_data(connect, collection) |
||
| 1038 | expr = {"must": [gen_default_vector_expr(default_query), |
||
| 1039 | gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])]} |
||
| 1040 | query = update_query_expr(default_query, expr=expr) |
||
| 1041 | res = connect.search(collection, query) |
||
| 1042 | assert len(res) == nq |
||
| 1043 | assert len(res[0]) == default_top_k |
||
| 1044 | # TODO: |
||
| 1045 | |||
| 1046 | # TODO: |
||
| 1047 | View Code Duplication | @pytest.mark.level(2) |
|
| 1048 | def test_query_term_values_repeat(self, connect, collection): |
||
| 1049 | ''' |
||
| 1050 | method: build query with vector and term expr, with the same values |
||
| 1051 | expected: filter pass |
||
| 1052 | ''' |
||
| 1053 | entities, ids = init_data(connect, collection) |
||
| 1054 | expr = { |
||
| 1055 | "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, default_nb)])]} |
||
| 1056 | query = update_query_expr(default_query, expr=expr) |
||
| 1057 | res = connect.search(collection, query) |
||
| 1058 | assert len(res) == nq |
||
| 1059 | assert len(res[0]) == 1 |
||
| 1060 | # TODO: |
||
| 1061 | |||
| 1062 | def test_query_term_value_empty(self, connect, collection): |
||
| 1063 | ''' |
||
| 1064 | method: build query with term value empty |
||
| 1065 | expected: return null |
||
| 1066 | ''' |
||
| 1067 | expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]} |
||
| 1068 | query = update_query_expr(default_query, expr=expr) |
||
| 1069 | res = connect.search(collection, query) |
||
| 1070 | assert len(res) == nq |
||
| 1071 | assert len(res[0]) == 0 |
||
| 1072 | |||
| 1073 | """ |
||
| 1074 | ****************************************************************** |
||
| 1075 | # The following cases are used to build invalid term query expr |
||
| 1076 | ****************************************************************** |
||
| 1077 | """ |
||
| 1078 | |||
| 1079 | # TODO |
||
| 1080 | @pytest.mark.level(2) |
||
| 1081 | def test_query_term_key_error(self, connect, collection): |
||
| 1082 | ''' |
||
| 1083 | method: build query with term key error |
||
| 1084 | expected: Exception raised |
||
| 1085 | ''' |
||
| 1086 | expr = {"must": [gen_default_vector_expr(default_query), |
||
| 1087 | gen_default_term_expr(keyword="terrm", values=[i for i in range(default_nb // 2)])]} |
||
| 1088 | query = update_query_expr(default_query, expr=expr) |
||
| 1089 | with pytest.raises(Exception) as e: |
||
| 1090 | res = connect.search(collection, query) |
||
| 1091 | |||
    @pytest.fixture(
        scope="function",
        params=gen_invalid_term()
    )
    def get_invalid_term(self, request):
        """Parametrized fixture yielding each invalid term expression in turn."""
        return request.param
||
| 1098 | |||
| 1099 | @pytest.mark.level(2) |
||
| 1100 | def test_query_term_wrong_format(self, connect, collection, get_invalid_term): |
||
| 1101 | ''' |
||
| 1102 | method: build query with wrong format term |
||
| 1103 | expected: Exception raised |
||
| 1104 | ''' |
||
| 1105 | entities, ids = init_data(connect, collection) |
||
| 1106 | term = get_invalid_term |
||
| 1107 | expr = {"must": [gen_default_vector_expr(default_query), term]} |
||
| 1108 | query = update_query_expr(default_query, expr=expr) |
||
| 1109 | with pytest.raises(Exception) as e: |
||
| 1110 | res = connect.search(collection, query) |
||
| 1111 | |||
| 1112 | # TODO |
||
| 1113 | @pytest.mark.level(2) |
||
| 1114 | def test_query_term_field_named_term(self, connect, collection): |
||
| 1115 | ''' |
||
| 1116 | method: build query with field named "term" |
||
| 1117 | expected: error raised |
||
| 1118 | ''' |
||
| 1119 | term_fields = add_field_default(default_fields, field_name="term") |
||
| 1120 | collection_term = gen_unique_str("term") |
||
| 1121 | connect.create_collection(collection_term, term_fields) |
||
| 1122 | term_entities = add_field(entities, field_name="term") |
||
| 1123 | ids = connect.insert(collection_term, term_entities) |
||
| 1124 | assert len(ids) == default_nb |
||
| 1125 | connect.flush([collection_term]) |
||
| 1126 | count = connect.count_entities(collection_term) |
||
| 1127 | assert count == default_nb |
||
| 1128 | term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}} |
||
| 1129 | expr = {"must": [gen_default_vector_expr(default_query), |
||
| 1130 | term_param]} |
||
| 1131 | query = update_query_expr(default_query, expr=expr) |
||
| 1132 | res = connect.search(collection_term, query) |
||
| 1133 | assert len(res) == nq |
||
| 1134 | assert len(res[0]) == default_top_k |
||
| 1135 | connect.drop_collection(collection_term) |
||
| 1136 | |||
| 1137 | @pytest.mark.level(2) |
||
| 1138 | def test_query_term_one_field_not_existed(self, connect, collection): |
||
| 1139 | ''' |
||
| 1140 | method: build query with two fields term, one of it not existed |
||
| 1141 | expected: exception raised |
||
| 1142 | ''' |
||
| 1143 | entities, ids = init_data(connect, collection) |
||
| 1144 | term = gen_default_term_expr() |
||
| 1145 | term["term"].update({"a": [0]}) |
||
| 1146 | expr = {"must": [gen_default_vector_expr(default_query), term]} |
||
| 1147 | query = update_query_expr(default_query, expr=expr) |
||
| 1148 | with pytest.raises(Exception) as e: |
||
| 1149 | res = connect.search(collection, query) |
||
| 1150 | |||
| 1151 | """ |
||
| 1152 | ****************************************************************** |
||
| 1153 | # The following cases are used to build valid range query expr |
||
| 1154 | ****************************************************************** |
||
| 1155 | """ |
||
| 1156 | |||
| 1157 | # TODO |
||
| 1158 | def test_query_range_key_error(self, connect, collection): |
||
| 1159 | ''' |
||
| 1160 | method: build query with range key error |
||
| 1161 | expected: Exception raised |
||
| 1162 | ''' |
||
| 1163 | range = gen_default_range_expr(keyword="ranges") |
||
| 1164 | expr = {"must": [gen_default_vector_expr(default_query), range]} |
||
| 1165 | query = update_query_expr(default_query, expr=expr) |
||
| 1166 | with pytest.raises(Exception) as e: |
||
| 1167 | res = connect.search(collection, query) |
||
| 1168 | |||
    @pytest.fixture(
        scope="function",
        params=gen_invalid_range()
    )
    def get_invalid_range(self, request):
        """Parametrized fixture yielding each invalid range expression in turn."""
        return request.param
||
| 1175 | |||
| 1176 | # TODO |
||
| 1177 | @pytest.mark.level(2) |
||
| 1178 | def test_query_range_wrong_format(self, connect, collection, get_invalid_range): |
||
| 1179 | ''' |
||
| 1180 | method: build query with wrong format range |
||
| 1181 | expected: Exception raised |
||
| 1182 | ''' |
||
| 1183 | entities, ids = init_data(connect, collection) |
||
| 1184 | range = get_invalid_range |
||
| 1185 | expr = {"must": [gen_default_vector_expr(default_query), range]} |
||
| 1186 | query = update_query_expr(default_query, expr=expr) |
||
| 1187 | with pytest.raises(Exception) as e: |
||
| 1188 | res = connect.search(collection, query) |
||
| 1189 | |||
| 1190 | @pytest.mark.level(2) |
||
| 1191 | def test_query_range_string_ranges(self, connect, collection): |
||
| 1192 | ''' |
||
| 1193 | method: build query with invalid ranges |
||
| 1194 | expected: raise Exception |
||
| 1195 | ''' |
||
| 1196 | entities, ids = init_data(connect, collection) |
||
| 1197 | ranges = {"GT": "0", "LT": "1000"} |
||
| 1198 | range = gen_default_range_expr(ranges=ranges) |
||
| 1199 | expr = {"must": [gen_default_vector_expr(default_query), range]} |
||
| 1200 | query = update_query_expr(default_query, expr=expr) |
||
| 1201 | with pytest.raises(Exception) as e: |
||
| 1202 | res = connect.search(collection, query) |
||
| 1203 | |||
| 1204 | @pytest.mark.level(2) |
||
| 1205 | def test_query_range_invalid_ranges(self, connect, collection): |
||
| 1206 | ''' |
||
| 1207 | method: build query with invalid ranges |
||
| 1208 | expected: 0 |
||
| 1209 | ''' |
||
| 1210 | entities, ids = init_data(connect, collection) |
||
| 1211 | ranges = {"GT": default_nb, "LT": 0} |
||
| 1212 | range = gen_default_range_expr(ranges=ranges) |
||
| 1213 | expr = {"must": [gen_default_vector_expr(default_query), range]} |
||
| 1214 | query = update_query_expr(default_query, expr=expr) |
||
| 1215 | res = connect.search(collection, query) |
||
| 1216 | assert len(res[0]) == 0 |
||
| 1217 | |||
    @pytest.fixture(
        scope="function",
        params=gen_valid_ranges()
    )
    def get_valid_ranges(self, request):
        """Parametrized fixture yielding each valid range specification in turn."""
        return request.param
||
| 1224 | |||
| 1225 | @pytest.mark.level(2) |
||
| 1226 | def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges): |
||
| 1227 | ''' |
||
| 1228 | method: build query with valid ranges |
||
| 1229 | expected: pass |
||
| 1230 | ''' |
||
| 1231 | entities, ids = init_data(connect, collection) |
||
| 1232 | ranges = get_valid_ranges |
||
| 1233 | range = gen_default_range_expr(ranges=ranges) |
||
| 1234 | expr = {"must": [gen_default_vector_expr(default_query), range]} |
||
| 1235 | query = update_query_expr(default_query, expr=expr) |
||
| 1236 | res = connect.search(collection, query) |
||
| 1237 | assert len(res) == nq |
||
| 1238 | assert len(res[0]) == default_top_k |
||
| 1239 | |||
| 1240 | def test_query_range_one_field_not_existed(self, connect, collection): |
||
| 1241 | ''' |
||
| 1242 | method: build query with two fields ranges, one of fields not existed |
||
| 1243 | expected: exception raised |
||
| 1244 | ''' |
||
| 1245 | entities, ids = init_data(connect, collection) |
||
| 1246 | range = gen_default_range_expr() |
||
| 1247 | range["range"].update({"a": {"GT": 1, "LT": default_nb // 2}}) |
||
| 1248 | expr = {"must": [gen_default_vector_expr(default_query), range]} |
||
| 1249 | query = update_query_expr(default_query, expr=expr) |
||
| 1250 | with pytest.raises(Exception) as e: |
||
| 1251 | res = connect.search(collection, query) |
||
| 1252 | |||
| 1253 | """ |
||
| 1254 | ************************************************************************ |
||
| 1255 | # The following cases are used to build query expr multi range and term |
||
| 1256 | ************************************************************************ |
||
| 1257 | """ |
||
| 1258 | |||
| 1259 | # TODO |
||
| 1260 | View Code Duplication | @pytest.mark.level(2) |
|
| 1261 | def test_query_multi_term_has_common(self, connect, collection): |
||
| 1262 | ''' |
||
| 1263 | method: build query with multi term with same field, and values has common |
||
| 1264 | expected: pass |
||
| 1265 | ''' |
||
| 1266 | entities, ids = init_data(connect, collection) |
||
| 1267 | term_first = gen_default_term_expr() |
||
| 1268 | term_second = gen_default_term_expr(values=[i for i in range(default_nb // 3)]) |
||
| 1269 | expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} |
||
| 1270 | query = update_query_expr(default_query, expr=expr) |
||
| 1271 | res = connect.search(collection, query) |
||
| 1272 | assert len(res) == nq |
||
| 1273 | assert len(res[0]) == default_top_k |
||
| 1274 | |||
| 1275 | # TODO |
||
| 1276 | View Code Duplication | @pytest.mark.level(2) |
|
| 1277 | def test_query_multi_term_no_common(self, connect, collection): |
||
| 1278 | ''' |
||
| 1279 | method: build query with multi range with same field, and ranges no common |
||
| 1280 | expected: pass |
||
| 1281 | ''' |
||
| 1282 | entities, ids = init_data(connect, collection) |
||
| 1283 | term_first = gen_default_term_expr() |
||
| 1284 | term_second = gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)]) |
||
| 1285 | expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} |
||
| 1286 | query = update_query_expr(default_query, expr=expr) |
||
| 1287 | res = connect.search(collection, query) |
||
| 1288 | assert len(res) == nq |
||
| 1289 | assert len(res[0]) == 0 |
||
| 1290 | |||
| 1291 | # TODO |
||
| 1292 | View Code Duplication | def test_query_multi_term_different_fields(self, connect, collection): |
|
| 1293 | ''' |
||
| 1294 | method: build query with multi range with same field, and ranges no common |
||
| 1295 | expected: pass |
||
| 1296 | ''' |
||
| 1297 | entities, ids = init_data(connect, collection) |
||
| 1298 | term_first = gen_default_term_expr() |
||
| 1299 | term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(default_nb // 2, default_nb)]) |
||
| 1300 | expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} |
||
| 1301 | query = update_query_expr(default_query, expr=expr) |
||
| 1302 | res = connect.search(collection, query) |
||
| 1303 | assert len(res) == nq |
||
| 1304 | assert len(res[0]) == 0 |
||
| 1305 | |||
| 1306 | # TODO |
||
| 1307 | @pytest.mark.level(2) |
||
| 1308 | def test_query_single_term_multi_fields(self, connect, collection): |
||
| 1309 | ''' |
||
| 1310 | method: build query with multi term, different field each term |
||
| 1311 | expected: pass |
||
| 1312 | ''' |
||
| 1313 | entities, ids = init_data(connect, collection) |
||
| 1314 | term_first = {"int64": {"values": [i for i in range(default_nb // 2)]}} |
||
| 1315 | term_second = {"float": {"values": [float(i) for i in range(default_nb // 2, default_nb)]}} |
||
| 1316 | term = update_term_expr({"term": {}}, [term_first, term_second]) |
||
| 1317 | expr = {"must": [gen_default_vector_expr(default_query), term]} |
||
| 1318 | query = update_query_expr(default_query, expr=expr) |
||
| 1319 | with pytest.raises(Exception) as e: |
||
| 1320 | res = connect.search(collection, query) |
||
| 1321 | |||
| 1322 | # TODO |
||
| 1323 | @pytest.mark.level(2) |
||
| 1324 | def test_query_multi_range_has_common(self, connect, collection): |
||
| 1325 | ''' |
||
| 1326 | method: build query with multi range with same field, and ranges has common |
||
| 1327 | expected: pass |
||
| 1328 | ''' |
||
| 1329 | entities, ids = init_data(connect, collection) |
||
| 1330 | range_one = gen_default_range_expr() |
||
| 1331 | range_two = gen_default_range_expr(ranges={"GT": 1, "LT": default_nb // 3}) |
||
| 1332 | expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]} |
||
| 1333 | query = update_query_expr(default_query, expr=expr) |
||
| 1334 | res = connect.search(collection, query) |
||
| 1335 | assert len(res) == nq |
||
| 1336 | assert len(res[0]) == default_top_k |
||
| 1337 | |||
| 1338 | # TODO |
||
| 1339 | @pytest.mark.level(2) |
||
| 1340 | def test_query_multi_range_no_common(self, connect, collection): |
||
| 1341 | ''' |
||
| 1342 | method: build query with multi range with same field, and ranges no common |
||
| 1343 | expected: pass |
||
| 1344 | ''' |
||
| 1345 | entities, ids = init_data(connect, collection) |
||
| 1346 | range_one = gen_default_range_expr() |
||
| 1347 | range_two = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb}) |
||
| 1348 | expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]} |
||
| 1349 | query = update_query_expr(default_query, expr=expr) |
||
| 1350 | res = connect.search(collection, query) |
||
| 1351 | assert len(res) == nq |
||
| 1352 | assert len(res[0]) == 0 |
||
| 1353 | |||
| 1354 | # TODO |
||
| 1355 | @pytest.mark.level(2) |
||
| 1356 | def test_query_multi_range_different_fields(self, connect, collection): |
||
| 1357 | ''' |
||
| 1358 | method: build query with multi range, different field each range |
||
| 1359 | expected: pass |
||
| 1360 | ''' |
||
| 1361 | entities, ids = init_data(connect, collection) |
||
| 1362 | range_first = gen_default_range_expr() |
||
| 1363 | range_second = gen_default_range_expr(field="float", ranges={"GT": default_nb // 2, "LT": default_nb}) |
||
| 1364 | expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]} |
||
| 1365 | query = update_query_expr(default_query, expr=expr) |
||
| 1366 | res = connect.search(collection, query) |
||
| 1367 | assert len(res) == nq |
||
| 1368 | assert len(res[0]) == 0 |
||
| 1369 | |||
| 1370 | # TODO |
||
| 1371 | @pytest.mark.level(2) |
||
| 1372 | def test_query_single_range_multi_fields(self, connect, collection): |
||
| 1373 | ''' |
||
| 1374 | method: build query with multi range, different field each range |
||
| 1375 | expected: pass |
||
| 1376 | ''' |
||
| 1377 | entities, ids = init_data(connect, collection) |
||
| 1378 | range_first = {"int64": {"GT": 0, "LT": default_nb // 2}} |
||
| 1379 | range_second = {"float": {"GT": default_nb / 2, "LT": float(default_nb)}} |
||
| 1380 | range = update_range_expr({"range": {}}, [range_first, range_second]) |
||
| 1381 | expr = {"must": [gen_default_vector_expr(default_query), range]} |
||
| 1382 | query = update_query_expr(default_query, expr=expr) |
||
| 1383 | with pytest.raises(Exception) as e: |
||
| 1384 | res = connect.search(collection, query) |
||
| 1385 | |||
| 1386 | """ |
||
| 1387 | ****************************************************************** |
||
| 1388 | # The following cases are used to build query expr both term and range |
||
| 1389 | ****************************************************************** |
||
| 1390 | """ |
||
| 1391 | |||
| 1392 | # TODO |
||
| 1393 | View Code Duplication | @pytest.mark.level(2) |
|
| 1394 | def test_query_single_term_range_has_common(self, connect, collection): |
||
| 1395 | ''' |
||
| 1396 | method: build query with single term single range |
||
| 1397 | expected: pass |
||
| 1398 | ''' |
||
| 1399 | entities, ids = init_data(connect, collection) |
||
| 1400 | term = gen_default_term_expr() |
||
| 1401 | range = gen_default_range_expr(ranges={"GT": -1, "LT": default_nb // 2}) |
||
| 1402 | expr = {"must": [gen_default_vector_expr(default_query), term, range]} |
||
| 1403 | query = update_query_expr(default_query, expr=expr) |
||
| 1404 | res = connect.search(collection, query) |
||
| 1405 | assert len(res) == nq |
||
| 1406 | assert len(res[0]) == default_top_k |
||
| 1407 | |||
| 1408 | # TODO |
||
| 1409 | View Code Duplication | def test_query_single_term_range_no_common(self, connect, collection): |
|
| 1410 | ''' |
||
| 1411 | method: build query with single term single range |
||
| 1412 | expected: pass |
||
| 1413 | ''' |
||
| 1414 | entities, ids = init_data(connect, collection) |
||
| 1415 | term = gen_default_term_expr() |
||
| 1416 | range = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb}) |
||
| 1417 | expr = {"must": [gen_default_vector_expr(default_query), term, range]} |
||
| 1418 | query = update_query_expr(default_query, expr=expr) |
||
| 1419 | res = connect.search(collection, query) |
||
| 1420 | assert len(res) == nq |
||
| 1421 | assert len(res[0]) == 0 |
||
| 1422 | |||
| 1423 | """ |
||
| 1424 | ****************************************************************** |
||
| 1425 | # The following cases are used to build multi vectors query expr |
||
| 1426 | ****************************************************************** |
||
| 1427 | """ |
||
| 1428 | |||
| 1429 | # TODO |
||
| 1430 | def test_query_multi_vectors_same_field(self, connect, collection): |
||
| 1431 | ''' |
||
| 1432 | method: build query with two vectors same field |
||
| 1433 | expected: error raised |
||
| 1434 | ''' |
||
| 1435 | entities, ids = init_data(connect, collection) |
||
| 1436 | vector1 = default_query |
||
| 1437 | vector2 = gen_query_vectors(field_name, entities, default_top_k, nq=2) |
||
| 1438 | expr = { |
||
| 1439 | "must": [vector1, vector2] |
||
| 1440 | } |
||
| 1441 | query = update_query_expr(default_query, expr=expr) |
||
| 1442 | with pytest.raises(Exception) as e: |
||
| 1443 | res = connect.search(collection, query) |
||
| 1444 | |||
| 1445 | |||
class TestSearchDSLBools(object):
    """
    ******************************************************************
    # The following cases are used to build invalid query expr
    ******************************************************************
    """

    @pytest.mark.level(2)
    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        entities, ids = init_data(connect, collection)
        expr = {"bool1": {}}
        query = expr
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        # FIX: the original passed the helper function object itself
        # ({"should": gen_default_term_expr}) rather than the term expr it
        # builds, so the query was rejected for the wrong reason.
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        # FIX: call the helper instead of passing the function object (see
        # test_query_should_only_term).
        expr = {"must_not": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_should(self, connect, collection):
        '''
        method: build query must, and with should.term
        expected: error raised
        '''
        # FIX: call the helper instead of passing the function object (see
        # test_query_should_only_term).
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
||
| 1514 | |||
| 1515 | |||
| 1516 | """ |
||
| 1517 | ****************************************************************** |
||
| 1518 | # The following cases are used to test `search` function |
||
| 1519 | # with invalid collection_name, or invalid query expr |
||
| 1520 | ****************************************************************** |
||
| 1521 | """ |
||
| 1522 | |||
| 1523 | |||
class TestSearchInvalid(object):
    """
    Test search collection with invalid collection names
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # sq8h-style indexes need GPU support, so skip them in CPU mode.
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        # Searching a collection with an invalid name must raise.
        collection_name = get_collection_name
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, default_query)

    # TODO(yukun)
    @pytest.mark.level(2)
    def _test_search_with_invalid_tag(self, connect, collection):
        # Disabled (leading underscore): searching with a blank partition
        # tag is expected to raise once supported server-side.
        tag = " "
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        # Requesting an invalid output field name must raise.
        fields = [get_invalid_field]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        # Requesting a well-formed but nonexistent field name must raise.
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        # FIX: mutate a private deep copy rather than the shared module-level
        # default_query — the original modified it in place and leaked an
        # invalid topk into every later test that reuses default_query.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    Test search collection with invalid search params
    """

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        if index_type in ["FLAT"]:
            pytest.skip("skip in FLAT index")
        # Only pair invalid params with the index type they target.
        if index_type != search_params["index_type"]:
            pytest.skip("skip if index_type not matched")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.mark.level(2)
    def test_search_with_invalid_params_binary(self, connect, binary_collection):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        nq = 1
        index_type = "BIN_IVF_FLAT"
        int_vectors, entities, ids = init_binary_data(connect, binary_collection)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        connect.create_index(binary_collection, binary_field_name, {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 128}})
        # nprobe=0 is out of range for BIN_IVF_FLAT and must be rejected.
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, search_params={"nprobe": 0}, metric_type="JACCARD")
        with pytest.raises(Exception) as e:
            res = connect.search(binary_collection, query)

    @pytest.mark.level(2)
    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params={})
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
||
| 1671 | |||
| 1672 | |||
def check_id_result(result, id):
    """Return True when ``id`` is found among the leading result entities.

    At most the first five hits are inspected; when the result list is
    shorter than five, the whole list is scanned (slicing covers both
    cases, so no explicit length branch is needed).
    """
    limit_in = 5
    leading_ids = [entity.id for entity in result[:limit_in]]
    return id in leading_ids
||
| 1680 |