| Total Complexity | 158 |
| Total Lines | 1273 |
| Duplicated Lines | 27.1 % |
| Changes | 0 |
Duplicate code is one of the most common code smells. A rule of thumb is to restructure code once it is duplicated in three or more places.
Common duplication problems, and their corresponding solutions, are:
Complex classes like test_search often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | import time |
||
| 2 | import pdb |
||
| 3 | import copy |
||
| 4 | import threading |
||
| 5 | import logging |
||
| 6 | from multiprocessing import Pool, Process |
||
| 7 | import pytest |
||
| 8 | import numpy as np |
||
| 9 | |||
| 10 | from milvus import DataType |
||
| 11 | from utils import * |
||
| 12 | |||
| 13 | dim = 128 |
||
| 14 | segment_row_count = 5000 |
||
| 15 | top_k_limit = 2048 |
||
| 16 | collection_id = "search" |
||
| 17 | tag = "1970-01-01" |
||
| 18 | insert_interval_time = 1.5 |
||
| 19 | nb = 6000 |
||
| 20 | top_k = 10 |
||
| 21 | nq = 1 |
||
| 22 | nprobe = 1 |
||
| 23 | epsilon = 0.001 |
||
| 24 | field_name = default_float_vec_field_name |
||
|
|
|||
| 25 | default_fields = gen_default_fields() |
||
| 26 | search_param = {"nprobe": 1} |
||
| 27 | entity = gen_entities(1, is_normal=True) |
||
| 28 | raw_vector, binary_entity = gen_binary_entities(1) |
||
| 29 | entities = gen_entities(nb, is_normal=True) |
||
| 30 | raw_vectors, binary_entities = gen_binary_entities(nb) |
||
| 31 | default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 32 | |||
| 33 | |||
| 34 | def init_data(connect, collection, nb=6000, partition_tags=None): |
||
| 35 | ''' |
||
| 36 | Generate entities and add it in collection |
||
| 37 | ''' |
||
| 38 | global entities |
||
| 39 | if nb == 6000: |
||
| 40 | insert_entities = entities |
||
| 41 | else: |
||
| 42 | insert_entities = gen_entities(nb, is_normal=True) |
||
| 43 | if partition_tags is None: |
||
| 44 | ids = connect.insert(collection, insert_entities) |
||
| 45 | else: |
||
| 46 | ids = connect.insert(collection, insert_entities, partition_tag=partition_tags) |
||
| 47 | connect.flush([collection]) |
||
| 48 | return insert_entities, ids |
||
| 49 | |||
| 50 | |||
| 51 | def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None): |
||
| 52 | ''' |
||
| 53 | Generate entities and add it in collection |
||
| 54 | ''' |
||
| 55 | ids = [] |
||
| 56 | global binary_entities |
||
| 57 | global raw_vectors |
||
| 58 | if nb == 6000: |
||
| 59 | insert_entities = binary_entities |
||
| 60 | insert_raw_vectors = raw_vectors |
||
| 61 | else: |
||
| 62 | insert_raw_vectors, insert_entities = gen_binary_entities(nb) |
||
| 63 | if insert is True: |
||
| 64 | if partition_tags is None: |
||
| 65 | ids = connect.insert(collection, insert_entities) |
||
| 66 | else: |
||
| 67 | ids = connect.insert(collection, insert_entities, partition_tag=partition_tags) |
||
| 68 | connect.flush([collection]) |
||
| 69 | return insert_raw_vectors, insert_entities, ids |
||
| 70 | |||
| 71 | |||
| 72 | class TestSearchBase: |
||
| 73 | """ |
||
| 74 | generate valid create_index params |
||
| 75 | """ |
||
| 76 | |||
| 77 | @pytest.fixture( |
||
| 78 | scope="function", |
||
| 79 | params=gen_index() |
||
| 80 | ) |
||
| 81 | def get_index(self, request, connect): |
||
| 82 | if str(connect._cmd("mode")) == "CPU": |
||
| 83 | if request.param["index_type"] in index_cpu_not_support(): |
||
| 84 | pytest.skip("sq8h not support in CPU mode") |
||
| 85 | return request.param |
||
| 86 | |||
| 87 | @pytest.fixture( |
||
| 88 | scope="function", |
||
| 89 | params=gen_simple_index() |
||
| 90 | ) |
||
| 91 | def get_simple_index(self, request, connect): |
||
| 92 | if str(connect._cmd("mode")) == "CPU": |
||
| 93 | if request.param["index_type"] in index_cpu_not_support(): |
||
| 94 | pytest.skip("sq8h not support in CPU mode") |
||
| 95 | return request.param |
||
| 96 | |||
| 97 | @pytest.fixture( |
||
| 98 | scope="function", |
||
| 99 | params=gen_simple_index() |
||
| 100 | ) |
||
| 101 | def get_jaccard_index(self, request, connect): |
||
| 102 | logging.getLogger().info(request.param) |
||
| 103 | if request.param["index_type"] in binary_support(): |
||
| 104 | return request.param |
||
| 105 | else: |
||
| 106 | pytest.skip("Skip index Temporary") |
||
| 107 | |||
| 108 | @pytest.fixture( |
||
| 109 | scope="function", |
||
| 110 | params=gen_simple_index() |
||
| 111 | ) |
||
| 112 | def get_hamming_index(self, request, connect): |
||
| 113 | logging.getLogger().info(request.param) |
||
| 114 | if request.param["index_type"] in binary_support(): |
||
| 115 | return request.param |
||
| 116 | else: |
||
| 117 | pytest.skip("Skip index Temporary") |
||
| 118 | |||
| 119 | @pytest.fixture( |
||
| 120 | scope="function", |
||
| 121 | params=gen_simple_index() |
||
| 122 | ) |
||
| 123 | def get_structure_index(self, request, connect): |
||
| 124 | logging.getLogger().info(request.param) |
||
| 125 | if request.param["index_type"] == "FLAT": |
||
| 126 | return request.param |
||
| 127 | else: |
||
| 128 | pytest.skip("Skip index Temporary") |
||
| 129 | |||
| 130 | """ |
||
| 131 | generate top-k params |
||
| 132 | """ |
||
| 133 | |||
| 134 | @pytest.fixture( |
||
| 135 | scope="function", |
||
| 136 | params=[1, 10, 2049] |
||
| 137 | ) |
||
| 138 | def get_top_k(self, request): |
||
| 139 | yield request.param |
||
| 140 | |||
| 141 | @pytest.fixture( |
||
| 142 | scope="function", |
||
| 143 | params=[1, 10, 1100] |
||
| 144 | ) |
||
| 145 | def get_nq(self, request): |
||
| 146 | yield request.param |
||
| 147 | |||
| 148 | def test_search_flat(self, connect, collection, get_top_k, get_nq): |
||
| 149 | ''' |
||
| 150 | target: test basic search fuction, all the search params is corrent, change top-k value |
||
| 151 | method: search with the given vectors, check the result |
||
| 152 | expected: the length of the result is top_k |
||
| 153 | ''' |
||
| 154 | top_k = get_top_k |
||
| 155 | nq = get_nq |
||
| 156 | entities, ids = init_data(connect, collection) |
||
| 157 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 158 | if top_k <= top_k_limit: |
||
| 159 | res = connect.search(collection, query) |
||
| 160 | assert len(res[0]) == top_k |
||
| 161 | assert res[0]._distances[0] <= epsilon |
||
| 162 | assert check_id_result(res[0], ids[0]) |
||
| 163 | else: |
||
| 164 | with pytest.raises(Exception) as e: |
||
| 165 | res = connect.search(collection, query) |
||
| 166 | |||
| 167 | def test_search_field(self, connect, collection, get_top_k, get_nq): |
||
| 168 | ''' |
||
| 169 | target: test basic search fuction, all the search params is corrent, change top-k value |
||
| 170 | method: search with the given vectors, check the result |
||
| 171 | expected: the length of the result is top_k |
||
| 172 | ''' |
||
| 173 | top_k = get_top_k |
||
| 174 | nq = get_nq |
||
| 175 | entities, ids = init_data(connect, collection) |
||
| 176 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 177 | if top_k <= top_k_limit: |
||
| 178 | res = connect.search(collection, query, fields=["float_vector"]) |
||
| 179 | assert len(res[0]) == top_k |
||
| 180 | assert res[0]._distances[0] <= epsilon |
||
| 181 | assert check_id_result(res[0], ids[0]) |
||
| 182 | # TODO |
||
| 183 | res = connect.search(collection, query, fields=["float"]) |
||
| 184 | # TODO |
||
| 185 | else: |
||
| 186 | with pytest.raises(Exception) as e: |
||
| 187 | res = connect.search(collection, query) |
||
| 188 | |||
| 189 | @pytest.mark.level(2) |
||
| 190 | def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 191 | ''' |
||
| 192 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 193 | method: search with the given vectors, check the result |
||
| 194 | expected: the length of the result is top_k |
||
| 195 | ''' |
||
| 196 | top_k = get_top_k |
||
| 197 | nq = get_nq |
||
| 198 | |||
| 199 | index_type = get_simple_index["index_type"] |
||
| 200 | if index_type == "IVF_PQ": |
||
| 201 | pytest.skip("Skip PQ") |
||
| 202 | entities, ids = init_data(connect, collection) |
||
| 203 | connect.create_index(collection, field_name, get_simple_index) |
||
| 204 | search_param = get_search_param(index_type) |
||
| 205 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 206 | if top_k > top_k_limit: |
||
| 207 | with pytest.raises(Exception) as e: |
||
| 208 | res = connect.search(collection, query) |
||
| 209 | else: |
||
| 210 | res = connect.search(collection, query) |
||
| 211 | assert len(res) == nq |
||
| 212 | assert len(res[0]) >= top_k |
||
| 213 | assert res[0]._distances[0] < epsilon |
||
| 214 | assert check_id_result(res[0], ids[0]) |
||
| 215 | |||
| 216 | def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 217 | ''' |
||
| 218 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 219 | method: add vectors into collection, search with the given vectors, check the result |
||
| 220 | expected: the length of the result is top_k, search collection with partition tag return empty |
||
| 221 | ''' |
||
| 222 | top_k = get_top_k |
||
| 223 | nq = get_nq |
||
| 224 | |||
| 225 | index_type = get_simple_index["index_type"] |
||
| 226 | if index_type == "IVF_PQ": |
||
| 227 | pytest.skip("Skip PQ") |
||
| 228 | connect.create_partition(collection, tag) |
||
| 229 | entities, ids = init_data(connect, collection) |
||
| 230 | connect.create_index(collection, field_name, get_simple_index) |
||
| 231 | search_param = get_search_param(index_type) |
||
| 232 | query, vecs = gen_query_vectors_(field_name, entities, top_k, nq, search_params=search_param) |
||
| 233 | if top_k > top_k_limit: |
||
| 234 | with pytest.raises(Exception) as e: |
||
| 235 | res = connect.search(collection, query) |
||
| 236 | else: |
||
| 237 | res = connect.search(collection, query) |
||
| 238 | assert len(res) == nq |
||
| 239 | assert len(res[0]) >= top_k |
||
| 240 | assert res[0]._distances[0] < epsilon |
||
| 241 | assert check_id_result(res[0], ids[0]) |
||
| 242 | res = connect.search(collection, query, partition_tags=[tag]) |
||
| 243 | assert len(res) == nq |
||
| 244 | |||
| 245 | def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 246 | ''' |
||
| 247 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 248 | method: search with the given vectors, check the result |
||
| 249 | expected: the length of the result is top_k |
||
| 250 | ''' |
||
| 251 | top_k = get_top_k |
||
| 252 | nq = get_nq |
||
| 253 | |||
| 254 | index_type = get_simple_index["index_type"] |
||
| 255 | if index_type == "IVF_PQ": |
||
| 256 | pytest.skip("Skip PQ") |
||
| 257 | connect.create_partition(collection, tag) |
||
| 258 | entities, ids = init_data(connect, collection, partition_tags=tag) |
||
| 259 | connect.create_index(collection, field_name, get_simple_index) |
||
| 260 | search_param = get_search_param(index_type) |
||
| 261 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 262 | for tags in [[tag], [tag, "new_tag"]]: |
||
| 263 | if top_k > top_k_limit: |
||
| 264 | with pytest.raises(Exception) as e: |
||
| 265 | res = connect.search(collection, query, partition_tags=tags) |
||
| 266 | else: |
||
| 267 | res = connect.search(collection, query, partition_tags=tags) |
||
| 268 | assert len(res) == nq |
||
| 269 | assert len(res[0]) >= top_k |
||
| 270 | assert res[0]._distances[0] < epsilon |
||
| 271 | assert check_id_result(res[0], ids[0]) |
||
| 272 | |||
| 273 | @pytest.mark.level(2) |
||
| 274 | def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq): |
||
| 275 | ''' |
||
| 276 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 277 | method: search with the given vectors and tag (tag name not existed in collection), check the result |
||
| 278 | expected: error raised |
||
| 279 | ''' |
||
| 280 | top_k = get_top_k |
||
| 281 | nq = get_nq |
||
| 282 | entities, ids = init_data(connect, collection) |
||
| 283 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq) |
||
| 284 | if top_k > top_k_limit: |
||
| 285 | with pytest.raises(Exception) as e: |
||
| 286 | res = connect.search(collection, query, partition_tags=["new_tag"]) |
||
| 287 | else: |
||
| 288 | res = connect.search(collection, query, partition_tags=["new_tag"]) |
||
| 289 | assert len(res) == nq |
||
| 290 | assert len(res[0]) == 0 |
||
| 291 | |||
| 292 | View Code Duplication | @pytest.mark.level(2) |
|
| 293 | def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k): |
||
| 294 | ''' |
||
| 295 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 296 | method: search collection with the given vectors and tags, check the result |
||
| 297 | expected: the length of the result is top_k |
||
| 298 | ''' |
||
| 299 | top_k = get_top_k |
||
| 300 | nq = 2 |
||
| 301 | new_tag = "new_tag" |
||
| 302 | index_type = get_simple_index["index_type"] |
||
| 303 | if index_type == "IVF_PQ": |
||
| 304 | pytest.skip("Skip PQ") |
||
| 305 | connect.create_partition(collection, tag) |
||
| 306 | connect.create_partition(collection, new_tag) |
||
| 307 | entities, ids = init_data(connect, collection, partition_tags=tag) |
||
| 308 | new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag) |
||
| 309 | connect.create_index(collection, field_name, get_simple_index) |
||
| 310 | search_param = get_search_param(index_type) |
||
| 311 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 312 | if top_k > top_k_limit: |
||
| 313 | with pytest.raises(Exception) as e: |
||
| 314 | res = connect.search(collection, query) |
||
| 315 | else: |
||
| 316 | res = connect.search(collection, query) |
||
| 317 | assert check_id_result(res[0], ids[0]) |
||
| 318 | assert not check_id_result(res[1], new_ids[0]) |
||
| 319 | assert res[0]._distances[0] < epsilon |
||
| 320 | assert res[1]._distances[0] < epsilon |
||
| 321 | res = connect.search(collection, query, partition_tags=["new_tag"]) |
||
| 322 | assert res[0]._distances[0] > epsilon |
||
| 323 | assert res[1]._distances[0] > epsilon |
||
| 324 | |||
| 325 | # TODO: |
||
| 326 | View Code Duplication | @pytest.mark.level(2) |
|
| 327 | def _test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k): |
||
| 328 | ''' |
||
| 329 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 330 | method: search collection with the given vectors and tags, check the result |
||
| 331 | expected: the length of the result is top_k |
||
| 332 | ''' |
||
| 333 | top_k = get_top_k |
||
| 334 | nq = 2 |
||
| 335 | tag = "tag" |
||
| 336 | new_tag = "new_tag" |
||
| 337 | index_type = get_simple_index["index_type"] |
||
| 338 | if index_type == "IVF_PQ": |
||
| 339 | pytest.skip("Skip PQ") |
||
| 340 | connect.create_partition(collection, tag) |
||
| 341 | connect.create_partition(collection, new_tag) |
||
| 342 | entities, ids = init_data(connect, collection, partition_tags=tag) |
||
| 343 | new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag) |
||
| 344 | connect.create_index(collection, field_name, get_simple_index) |
||
| 345 | search_param = get_search_param(index_type) |
||
| 346 | query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param) |
||
| 347 | if top_k > top_k_limit: |
||
| 348 | with pytest.raises(Exception) as e: |
||
| 349 | res = connect.search(collection, query) |
||
| 350 | else: |
||
| 351 | res = connect.search(collection, query, partition_tags=["(.*)tag"]) |
||
| 352 | assert not check_id_result(res[0], ids[0]) |
||
| 353 | assert check_id_result(res[1], new_ids[0]) |
||
| 354 | assert res[0]._distances[0] > epsilon |
||
| 355 | assert res[1]._distances[0] < epsilon |
||
| 356 | res = connect.search(collection, query, partition_tags=["new(.*)"]) |
||
| 357 | assert res[0]._distances[0] > epsilon |
||
| 358 | assert res[1]._distances[0] < epsilon |
||
| 359 | |||
| 360 | # |
||
| 361 | # test for ip metric |
||
| 362 | # |
||
| 363 | @pytest.mark.level(2) |
||
| 364 | def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 365 | ''' |
||
| 366 | target: test basic search fuction, all the search params is corrent, change top-k value |
||
| 367 | method: search with the given vectors, check the result |
||
| 368 | expected: the length of the result is top_k |
||
| 369 | ''' |
||
| 370 | top_k = get_top_k |
||
| 371 | nq = get_nq |
||
| 372 | entities, ids = init_data(connect, collection) |
||
| 373 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP") |
||
| 374 | if top_k <= top_k_limit: |
||
| 375 | res = connect.search(collection, query) |
||
| 376 | assert len(res[0]) == top_k |
||
| 377 | assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) |
||
| 378 | assert check_id_result(res[0], ids[0]) |
||
| 379 | else: |
||
| 380 | with pytest.raises(Exception) as e: |
||
| 381 | res = connect.search(collection, query) |
||
| 382 | |||
| 383 | def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 384 | ''' |
||
| 385 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 386 | method: search with the given vectors, check the result |
||
| 387 | expected: the length of the result is top_k |
||
| 388 | ''' |
||
| 389 | top_k = get_top_k |
||
| 390 | nq = get_nq |
||
| 391 | |||
| 392 | index_type = get_simple_index["index_type"] |
||
| 393 | if index_type == "IVF_PQ": |
||
| 394 | pytest.skip("Skip PQ") |
||
| 395 | entities, ids = init_data(connect, collection) |
||
| 396 | get_simple_index["metric_type"] = "IP" |
||
| 397 | connect.create_index(collection, field_name, get_simple_index) |
||
| 398 | search_param = get_search_param(index_type) |
||
| 399 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param) |
||
| 400 | if top_k > top_k_limit: |
||
| 401 | with pytest.raises(Exception) as e: |
||
| 402 | res = connect.search(collection, query) |
||
| 403 | else: |
||
| 404 | res = connect.search(collection, query) |
||
| 405 | assert len(res) == nq |
||
| 406 | assert len(res[0]) >= top_k |
||
| 407 | assert check_id_result(res[0], ids[0]) |
||
| 408 | assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) |
||
| 409 | |||
| 410 | @pytest.mark.level(2) |
||
| 411 | def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): |
||
| 412 | ''' |
||
| 413 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 414 | method: add vectors into collection, search with the given vectors, check the result |
||
| 415 | expected: the length of the result is top_k, search collection with partition tag return empty |
||
| 416 | ''' |
||
| 417 | top_k = get_top_k |
||
| 418 | nq = get_nq |
||
| 419 | metric_type = "IP" |
||
| 420 | index_type = get_simple_index["index_type"] |
||
| 421 | if index_type == "IVF_PQ": |
||
| 422 | pytest.skip("Skip PQ") |
||
| 423 | connect.create_partition(collection, tag) |
||
| 424 | entities, ids = init_data(connect, collection) |
||
| 425 | get_simple_index["metric_type"] = metric_type |
||
| 426 | connect.create_index(collection, field_name, get_simple_index) |
||
| 427 | search_param = get_search_param(index_type) |
||
| 428 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, |
||
| 429 | search_params=search_param) |
||
| 430 | if top_k > top_k_limit: |
||
| 431 | with pytest.raises(Exception) as e: |
||
| 432 | res = connect.search(collection, query) |
||
| 433 | else: |
||
| 434 | res = connect.search(collection, query) |
||
| 435 | assert len(res) == nq |
||
| 436 | assert len(res[0]) >= top_k |
||
| 437 | assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) |
||
| 438 | assert check_id_result(res[0], ids[0]) |
||
| 439 | res = connect.search(collection, query, partition_tags=[tag]) |
||
| 440 | assert len(res) == nq |
||
| 441 | |||
| 442 | @pytest.mark.level(2) |
||
| 443 | def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k): |
||
| 444 | ''' |
||
| 445 | target: test basic search fuction, all the search params is corrent, test all index params, and build |
||
| 446 | method: search collection with the given vectors and tags, check the result |
||
| 447 | expected: the length of the result is top_k |
||
| 448 | ''' |
||
| 449 | top_k = get_top_k |
||
| 450 | nq = 2 |
||
| 451 | metric_type = "IP" |
||
| 452 | new_tag = "new_tag" |
||
| 453 | index_type = get_simple_index["index_type"] |
||
| 454 | if index_type == "IVF_PQ": |
||
| 455 | pytest.skip("Skip PQ") |
||
| 456 | connect.create_partition(collection, tag) |
||
| 457 | connect.create_partition(collection, new_tag) |
||
| 458 | entities, ids = init_data(connect, collection, partition_tags=tag) |
||
| 459 | new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag) |
||
| 460 | get_simple_index["metric_type"] = metric_type |
||
| 461 | connect.create_index(collection, field_name, get_simple_index) |
||
| 462 | search_param = get_search_param(index_type) |
||
| 463 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param) |
||
| 464 | if top_k > top_k_limit: |
||
| 465 | with pytest.raises(Exception) as e: |
||
| 466 | res = connect.search(collection, query) |
||
| 467 | else: |
||
| 468 | res = connect.search(collection, query) |
||
| 469 | assert check_id_result(res[0], ids[0]) |
||
| 470 | assert not check_id_result(res[1], new_ids[0]) |
||
| 471 | assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) |
||
| 472 | assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0]) |
||
| 473 | res = connect.search(collection, query, partition_tags=["new_tag"]) |
||
| 474 | assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0]) |
||
| 475 | # TODO: |
||
| 476 | # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0]) |
||
| 477 | |||
| 478 | @pytest.mark.level(2) |
||
| 479 | def test_search_without_connect(self, dis_connect, collection): |
||
| 480 | ''' |
||
| 481 | target: test search vectors without connection |
||
| 482 | method: use dis connected instance, call search method and check if search successfully |
||
| 483 | expected: raise exception |
||
| 484 | ''' |
||
| 485 | with pytest.raises(Exception) as e: |
||
| 486 | res = dis_connect.search(collection, default_query) |
||
| 487 | |||
| 488 | def test_search_collection_name_not_existed(self, connect): |
||
| 489 | ''' |
||
| 490 | target: search collection not existed |
||
| 491 | method: search with the random collection_name, which is not in db |
||
| 492 | expected: status not ok |
||
| 493 | ''' |
||
| 494 | collection_name = gen_unique_str(collection_id) |
||
| 495 | with pytest.raises(Exception) as e: |
||
| 496 | res = connect.search(collection_name, default_query) |
||
| 497 | |||
| 498 | View Code Duplication | def test_search_distance_l2(self, connect, collection): |
|
| 499 | ''' |
||
| 500 | target: search collection, and check the result: distance |
||
| 501 | method: compare the return distance value with value computed with Euclidean |
||
| 502 | expected: the return distance equals to the computed value |
||
| 503 | ''' |
||
| 504 | nq = 2 |
||
| 505 | search_param = {"nprobe": 1} |
||
| 506 | entities, ids = init_data(connect, collection, nb=nq) |
||
| 507 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param) |
||
| 508 | inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 509 | distance_0 = l2(vecs[0], inside_vecs[0]) |
||
| 510 | distance_1 = l2(vecs[0], inside_vecs[1]) |
||
| 511 | res = connect.search(collection, query) |
||
| 512 | assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) |
||
| 513 | |||
| 514 | # TODO: distance problem |
||
| 515 | View Code Duplication | def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index): |
|
| 516 | ''' |
||
| 517 | target: search collection, and check the result: distance |
||
| 518 | method: compare the return distance value with value computed with Inner product |
||
| 519 | expected: the return distance equals to the computed value |
||
| 520 | ''' |
||
| 521 | index_type = get_simple_index["index_type"] |
||
| 522 | nq = 2 |
||
| 523 | entities, ids = init_data(connect, collection) |
||
| 524 | connect.create_index(collection, field_name, get_simple_index) |
||
| 525 | search_param = get_search_param(index_type) |
||
| 526 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param) |
||
| 527 | inside_vecs = entities[-1]["values"] |
||
| 528 | min_distance = 1.0 |
||
| 529 | for i in range(nb): |
||
| 530 | tmp_dis = l2(vecs[0], inside_vecs[i]) |
||
| 531 | if min_distance > tmp_dis: |
||
| 532 | min_distance = tmp_dis |
||
| 533 | res = connect.search(collection, query) |
||
| 534 | assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0]) |
||
| 535 | |||
| 536 | View Code Duplication | def test_search_distance_ip(self, connect, collection): |
|
| 537 | ''' |
||
| 538 | target: search collection, and check the result: distance |
||
| 539 | method: compare the return distance value with value computed with Inner product |
||
| 540 | expected: the return distance equals to the computed value |
||
| 541 | ''' |
||
| 542 | nq = 2 |
||
| 543 | metirc_type = "IP" |
||
| 544 | search_param = {"nprobe": 1} |
||
| 545 | entities, ids = init_data(connect, collection, nb=nq) |
||
| 546 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type, |
||
| 547 | search_params=search_param) |
||
| 548 | inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 549 | distance_0 = ip(vecs[0], inside_vecs[0]) |
||
| 550 | distance_1 = ip(vecs[0], inside_vecs[1]) |
||
| 551 | res = connect.search(collection, query) |
||
| 552 | assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) |
||
| 553 | |||
| 554 | # TODO: distance problem |
||
| 555 | View Code Duplication | def _test_search_distance_ip_after_index(self, connect, collection, get_simple_index): |
|
| 556 | ''' |
||
| 557 | target: search collection, and check the result: distance |
||
| 558 | method: compare the return distance value with value computed with Inner product |
||
| 559 | expected: the return distance equals to the computed value |
||
| 560 | ''' |
||
| 561 | index_type = get_simple_index["index_type"] |
||
| 562 | nq = 2 |
||
| 563 | metirc_type = "IP" |
||
| 564 | entities, ids = init_data(connect, collection) |
||
| 565 | get_simple_index["metric_type"] = metirc_type |
||
| 566 | connect.create_index(collection, field_name, get_simple_index) |
||
| 567 | search_param = get_search_param(index_type) |
||
| 568 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type, |
||
| 569 | search_params=search_param) |
||
| 570 | inside_vecs = entities[-1]["values"] |
||
| 571 | max_distance = 0 |
||
| 572 | for i in range(nb): |
||
| 573 | tmp_dis = ip(vecs[0], inside_vecs[i]) |
||
| 574 | if max_distance < tmp_dis: |
||
| 575 | max_distance = tmp_dis |
||
| 576 | res = connect.search(collection, query) |
||
| 577 | assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0]) |
||
| 578 | |||
| 579 | # TODO: |
||
| 580 | def _test_search_distance_jaccard_flat_index(self, connect, binary_collection): |
||
| 581 | ''' |
||
| 582 | target: search binary_collection, and check the result: distance |
||
| 583 | method: compare the return distance value with value computed with Inner product |
||
| 584 | expected: the return distance equals to the computed value |
||
| 585 | ''' |
||
| 586 | # from scipy.spatial import distance |
||
| 587 | nprobe = 512 |
||
| 588 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 589 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 590 | distance_0 = jaccard(query_int_vectors[0], int_vectors[0]) |
||
| 591 | distance_1 = jaccard(query_int_vectors[0], int_vectors[1]) |
||
| 592 | res = connect.search(binary_collection, query_entities) |
||
| 593 | assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon |
||
| 594 | |||
| 595 | def _test_search_distance_hamming_flat_index(self, connect, binary_collection): |
||
| 596 | ''' |
||
| 597 | target: search binary_collection, and check the result: distance |
||
| 598 | method: compare the return distance value with value computed with Inner product |
||
| 599 | expected: the return distance equals to the computed value |
||
| 600 | ''' |
||
| 601 | # from scipy.spatial import distance |
||
| 602 | nprobe = 512 |
||
| 603 | int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) |
||
| 604 | query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 605 | distance_0 = hamming(query_int_vectors[0], int_vectors[0]) |
||
| 606 | distance_1 = hamming(query_int_vectors[0], int_vectors[1]) |
||
| 607 | res = connect.search(binary_collection, query_entities) |
||
| 608 | assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon |
||
| 609 | |||
| 610 | View Code Duplication | def _test_search_distance_substructure_flat_index(self, connect, binary_collection): |
|
| 611 | ''' |
||
| 612 | target: search binary_collection, and check the result: distance |
||
| 613 | method: compare the return distance value with value computed with Inner product |
||
| 614 | expected: the return distance equals to the computed value |
||
| 615 | ''' |
||
| 616 | # from scipy.spatial import distance |
||
| 617 | nprobe = 512 |
||
| 618 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
| 619 | index_type = "FLAT" |
||
| 620 | index_param = { |
||
| 621 | "nlist": 16384, |
||
| 622 | "metric_type": "SUBSTRUCTURE" |
||
| 623 | } |
||
| 624 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
| 625 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
| 626 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
| 627 | query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 628 | distance_0 = substructure(query_int_vectors[0], int_vectors[0]) |
||
| 629 | distance_1 = substructure(query_int_vectors[0], int_vectors[1]) |
||
| 630 | search_param = get_search_param(index_type) |
||
| 631 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
| 632 | logging.getLogger().info(status) |
||
| 633 | logging.getLogger().info(result) |
||
| 634 | assert len(result[0]) == 0 |
||
| 635 | |||
| 636 | View Code Duplication | def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection): |
|
| 637 | ''' |
||
| 638 | target: search binary_collection, and check the result: distance |
||
| 639 | method: compare the return distance value with value computed with SUB |
||
| 640 | expected: the return distance equals to the computed value |
||
| 641 | ''' |
||
| 642 | # from scipy.spatial import distance |
||
| 643 | top_k = 3 |
||
| 644 | nprobe = 512 |
||
| 645 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
| 646 | index_type = "FLAT" |
||
| 647 | index_param = { |
||
| 648 | "nlist": 16384, |
||
| 649 | "metric_type": "SUBSTRUCTURE" |
||
| 650 | } |
||
| 651 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
| 652 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
| 653 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
| 654 | query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2) |
||
| 655 | search_param = get_search_param(index_type) |
||
| 656 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
| 657 | logging.getLogger().info(status) |
||
| 658 | logging.getLogger().info(result) |
||
| 659 | assert len(result[0]) == 1 |
||
| 660 | assert len(result[1]) == 1 |
||
| 661 | assert result[0][0].distance <= epsilon |
||
| 662 | assert result[0][0].id == ids[0] |
||
| 663 | assert result[1][0].distance <= epsilon |
||
| 664 | assert result[1][0].id == ids[1] |
||
| 665 | |||
| 666 | View Code Duplication | def _test_search_distance_superstructure_flat_index(self, connect, binary_collection): |
|
| 667 | ''' |
||
| 668 | target: search binary_collection, and check the result: distance |
||
| 669 | method: compare the return distance value with value computed with Inner product |
||
| 670 | expected: the return distance equals to the computed value |
||
| 671 | ''' |
||
| 672 | # from scipy.spatial import distance |
||
| 673 | nprobe = 512 |
||
| 674 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
| 675 | index_type = "FLAT" |
||
| 676 | index_param = { |
||
| 677 | "nlist": 16384, |
||
| 678 | "metric_type": "SUBSTRUCTURE" |
||
| 679 | } |
||
| 680 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
| 681 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
| 682 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
| 683 | query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 684 | distance_0 = superstructure(query_int_vectors[0], int_vectors[0]) |
||
| 685 | distance_1 = superstructure(query_int_vectors[0], int_vectors[1]) |
||
| 686 | search_param = get_search_param(index_type) |
||
| 687 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
| 688 | logging.getLogger().info(status) |
||
| 689 | logging.getLogger().info(result) |
||
| 690 | assert len(result[0]) == 0 |
||
| 691 | |||
| 692 | View Code Duplication | def _test_search_distance_superstructure_flat_index_B(self, connect, binary_collection): |
|
| 693 | ''' |
||
| 694 | target: search binary_collection, and check the result: distance |
||
| 695 | method: compare the return distance value with value computed with SUPER |
||
| 696 | expected: the return distance equals to the computed value |
||
| 697 | ''' |
||
| 698 | # from scipy.spatial import distance |
||
| 699 | top_k = 3 |
||
| 700 | nprobe = 512 |
||
| 701 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
| 702 | index_type = "FLAT" |
||
| 703 | index_param = { |
||
| 704 | "nlist": 16384, |
||
| 705 | "metric_type": "SUBSTRUCTURE" |
||
| 706 | } |
||
| 707 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
| 708 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
| 709 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
| 710 | query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2) |
||
| 711 | search_param = get_search_param(index_type) |
||
| 712 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
| 713 | logging.getLogger().info(status) |
||
| 714 | logging.getLogger().info(result) |
||
| 715 | assert len(result[0]) == 2 |
||
| 716 | assert len(result[1]) == 2 |
||
| 717 | assert result[0][0].id in ids |
||
| 718 | assert result[0][0].distance <= epsilon |
||
| 719 | assert result[1][0].id in ids |
||
| 720 | assert result[1][0].distance <= epsilon |
||
| 721 | |||
| 722 | View Code Duplication | def _test_search_distance_tanimoto_flat_index(self, connect, binary_collection): |
|
| 723 | ''' |
||
| 724 | target: search binary_collection, and check the result: distance |
||
| 725 | method: compare the return distance value with value computed with Inner product |
||
| 726 | expected: the return distance equals to the computed value |
||
| 727 | ''' |
||
| 728 | # from scipy.spatial import distance |
||
| 729 | nprobe = 512 |
||
| 730 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
| 731 | index_type = "FLAT" |
||
| 732 | index_param = { |
||
| 733 | "nlist": 16384, |
||
| 734 | "metric_type": "TANIMOTO" |
||
| 735 | } |
||
| 736 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
| 737 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
| 738 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
| 739 | query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
| 740 | distance_0 = tanimoto(query_int_vectors[0], int_vectors[0]) |
||
| 741 | distance_1 = tanimoto(query_int_vectors[0], int_vectors[1]) |
||
| 742 | search_param = get_search_param(index_type) |
||
| 743 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
| 744 | logging.getLogger().info(status) |
||
| 745 | logging.getLogger().info(result) |
||
| 746 | assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon |
||
| 747 | |||
| 748 | @pytest.mark.timeout(30) |
||
| 749 | def test_search_concurrent_multithreads(self, connect, args): |
||
| 750 | ''' |
||
| 751 | target: test concurrent search with multiprocessess |
||
| 752 | method: search with 10 processes, each process uses dependent connection |
||
| 753 | expected: status ok and the returned vectors should be query_records |
||
| 754 | ''' |
||
| 755 | nb = 100 |
||
| 756 | top_k = 10 |
||
| 757 | threads_num = 4 |
||
| 758 | threads = [] |
||
| 759 | collection = gen_unique_str(collection_id) |
||
| 760 | uri = "tcp://%s:%s" % (args["ip"], args["port"]) |
||
| 761 | # create collection |
||
| 762 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
| 763 | milvus.create_collection(collection, default_fields) |
||
| 764 | entities, ids = init_data(milvus, collection) |
||
| 765 | |||
| 766 | def search(milvus): |
||
| 767 | res = connect.search(collection, default_query) |
||
| 768 | assert len(res) == 1 |
||
| 769 | assert res[0]._entities[0].id in ids |
||
| 770 | assert res[0]._distances[0] < epsilon |
||
| 771 | |||
| 772 | for i in range(threads_num): |
||
| 773 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
| 774 | t = threading.Thread(target=search, args=(milvus,)) |
||
| 775 | threads.append(t) |
||
| 776 | t.start() |
||
| 777 | time.sleep(0.2) |
||
| 778 | for t in threads: |
||
| 779 | t.join() |
||
| 780 | |||
| 781 | @pytest.mark.timeout(30) |
||
| 782 | def test_search_concurrent_multithreads_single_connection(self, connect, args): |
||
| 783 | ''' |
||
| 784 | target: test concurrent search with multiprocessess |
||
| 785 | method: search with 10 processes, each process uses dependent connection |
||
| 786 | expected: status ok and the returned vectors should be query_records |
||
| 787 | ''' |
||
| 788 | nb = 100 |
||
| 789 | top_k = 10 |
||
| 790 | threads_num = 4 |
||
| 791 | threads = [] |
||
| 792 | collection = gen_unique_str(collection_id) |
||
| 793 | uri = "tcp://%s:%s" % (args["ip"], args["port"]) |
||
| 794 | # create collection |
||
| 795 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
| 796 | milvus.create_collection(collection, default_fields) |
||
| 797 | entities, ids = init_data(milvus, collection) |
||
| 798 | |||
| 799 | def search(milvus): |
||
| 800 | res = connect.search(collection, default_query) |
||
| 801 | assert len(res) == 1 |
||
| 802 | assert res[0]._entities[0].id in ids |
||
| 803 | assert res[0]._distances[0] < epsilon |
||
| 804 | |||
| 805 | for i in range(threads_num): |
||
| 806 | t = threading.Thread(target=search, args=(milvus,)) |
||
| 807 | threads.append(t) |
||
| 808 | t.start() |
||
| 809 | time.sleep(0.2) |
||
| 810 | for t in threads: |
||
| 811 | t.join() |
||
| 812 | |||
| 813 | def test_search_multi_collections(self, connect, args): |
||
| 814 | ''' |
||
| 815 | target: test search multi collections of L2 |
||
| 816 | method: add vectors into 10 collections, and search |
||
| 817 | expected: search status ok, the length of result |
||
| 818 | ''' |
||
| 819 | num = 10 |
||
| 820 | top_k = 10 |
||
| 821 | nq = 20 |
||
| 822 | for i in range(num): |
||
| 823 | collection = gen_unique_str(collection_id + str(i)) |
||
| 824 | connect.create_collection(collection, default_fields) |
||
| 825 | entities, ids = init_data(connect, collection) |
||
| 826 | assert len(ids) == nb |
||
| 827 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
| 828 | res = connect.search(collection, query) |
||
| 829 | assert len(res) == nq |
||
| 830 | for i in range(nq): |
||
| 831 | assert check_id_result(res[i], ids[i]) |
||
| 832 | assert res[i]._distances[0] < epsilon |
||
| 833 | assert res[i]._distances[1] > epsilon |
||
| 834 | |||
| 835 | |||
class TestSearchDSL(object):
    """
    ******************************************************************
    # The following cases are used to build invalid query expr
    ******************************************************************
    """

    # TODO: assert exception
    def test_query_no_must(self, connect, collection):
        '''
        method: build query without must expr
        expected: error raised
        '''
        query = update_query_expr(default_query, keep_old=False)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # TODO:
    def test_query_no_vector_term_only(self, connect, collection):
        '''
        method: build query with a term expr only, no vector expr
        expected: error raised
        '''
        # NOTE(review): the original put the *function object* in the list;
        # call it so the expr actually contains a term expression.
        expr = {
            "must": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_vector_only(self, connect, collection):
        # A plain vector-only query returns top_k hits for each of nq queries.
        entities, ids = init_data(connect, collection)
        res = connect.search(collection, default_query)
        assert len(res) == nq
        assert len(res[0]) == top_k

    def test_query_wrong_format(self, connect, collection):
        '''
        method: build query without must expr, with a misspelled expr name
        expected: error raised
        '''
        # "must1" is the deliberate mistake; the term expr itself is valid.
        expr = {
            "must1": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_empty(self, connect, collection):
        '''
        method: search with empty query
        expected: error raised
        '''
        query = {}
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_with_wrong_format_term(self, connect, collection):
        '''
        method: build query with wrong-format term expr
        expected: error raised
        '''
        # NOTE(review): the original assigned into the *function object*
        # (`gen_default_term_expr["term"] = 1`), which raised TypeError before
        # the search was attempted — call the generator first.
        expr = gen_default_term_expr()
        expr["term"] = 1
        expr = {"must": [expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    ******************************************************************
    # The following cases are used to build valid query expr
    ******************************************************************
    """

    def test_query_term_value_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_value_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(values=[i for i in range(100000, 100010)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_parts_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with parts of term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        # Floor division: range() rejects float bounds on Python 3 (the
        # original used `nb / 2`).
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_repeat(self, connect, collection):
        '''
        method: build query with vector and term expr, with the same values
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, nb)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_value_empty(self, connect, collection):
        '''
        method: build query with empty term value list
        expected: empty result returned
        '''
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]}
        query = update_query_expr(default_query, expr=expr)
        logging.getLogger().info(query)
        res = connect.search(collection, query)
        logging.getLogger().info(res)
        assert len(res) == nq
        assert len(res[0]) == 0

    # TODO
    def test_query_term_key_error(self, connect, collection):
        '''
        method: build query with misspelled term keyword
        expected: error raised
        '''
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(keyword="terrm", values=[i for i in range(nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        logging.getLogger().info(query)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_term_wrong_format(self, connect, collection):
        '''
        method: build query with wrong format term
        expected: error raised
        '''
        term = {"term": 1}
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        logging.getLogger().info(query)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # TODO
    def test_query_term_wrong_format_null(self, connect, collection):
        '''
        method: build query with empty term body
        expected: error raised
        '''
        term = {"term": {}}
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        logging.getLogger().info(query)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # TODO
    def test_query_term_field_named_term(self, connect, collection):
        '''
        method: build query against a collection that has a field named "term"
        expected: search passes (the field name does not clash with the keyword)
        '''
        term_fields = add_field_default(default_fields, field_name="term")
        collection_term = gen_unique_str("term")
        connect.create_collection(collection_term, term_fields)
        term_entities = add_field(entities, field_name="term")
        ids = connect.insert(collection_term, term_entities)
        assert len(ids) == nb
        connect.flush([collection_term])
        count = connect.count_entities(collection_term)
        assert count == nb
        term_param = {"term": {"term": {"values": [i for i in range(nb // 2)]}}}
        expr = {"must": [gen_default_vector_expr(default_query),
                         term_param]}
        query = update_query_expr(default_query, expr=expr)
        logging.getLogger().info(query)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
        connect.drop_collection(collection_term)
| 1060 | |||
class TestSearchDSLBools(object):
    """
    ******************************************************************
    # The following cases are used to build invalid query expr
    ******************************************************************
    """

    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        # NOTE(review): the original searched with an undefined name `query`
        # and only passed via the resulting NameError — build the invalid
        # query explicitly so the server-side validation is what is tested.
        query = {"bool1": {}}
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        # NOTE(review): call gen_default_term_expr (the original passed the
        # function object itself).
        expr = {"should": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        expr = {"must_not": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_should(self, connect, collection):
        '''
        method: build query with must, and with should.term as well
        expected: error raised
        '''
        expr = {"should": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
| 1127 | |||
| 1128 | """ |
||
| 1129 | ****************************************************************** |
||
| 1130 | # The following cases are used to test `search` function |
||
| 1131 | # with invalid collection_name, or invalid query expr |
||
| 1132 | ****************************************************************** |
||
| 1133 | """ |
||
| 1134 | |||
| 1135 | |||
class TestSearchInvalid(object):
    """
    Test search collection with invalid collection names
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # SQ8H-style indexes are GPU-only; skip them when the server runs CPU mode.
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        collection_name = get_collection_name
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, default_query)

    @pytest.mark.level(1)
    def test_search_with_invalid_tag(self, connect, collection):
        tag = " "
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        fields = [get_invalid_field]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with an invalid top_k
        expected: raise an error, and the connection is normal
        '''
        # NOTE(review): the original wrote the invalid top_k into the shared
        # module-level default_query, polluting every later test — mutate a
        # private deep copy instead.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = get_top_k
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    Test search collection with invalid search params
    """

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    # TODO: This case can all pass, but it's too slow
    @pytest.mark.level(2)
    def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with the wrong nprobe
        method: search with an invalid nprobe
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        if search_params["index_type"] != index_type:
            pytest.skip("Skip case")
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params={}
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            # FLAT search accepts empty params, so the error would not occur.
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params={})
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
||
| 1264 | |||
| 1265 | |||
def check_id_result(result, id):
    """Check whether *id* appears among the top entity ids of *result*.

    When the result holds five or more entities, only the first five hits
    are considered; otherwise the whole result is scanned.
    """
    cutoff = 5
    candidate_ids = [entity.id for entity in result]
    if len(candidate_ids) < cutoff:
        return id in candidate_ids
    return id in candidate_ids[:cutoff]
||
| 1273 |