Total Complexity | 162 |
Total Lines | 1342 |
Duplicated Lines | 25.63 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems, and corresponding solutions are:
Complex modules like test_search often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | import time |
||
2 | import pdb |
||
3 | import copy |
||
4 | import threading |
||
5 | import logging |
||
6 | from multiprocessing import Pool, Process |
||
7 | import pytest |
||
8 | import numpy as np |
||
9 | |||
10 | from milvus import DataType |
||
11 | from utils import * |
||
12 | |||
# --- shared test configuration and pre-generated fixture data ---
dim = 128                    # dimensionality of the float test vectors
segment_row_count = 5000     # rows per segment when creating collections
top_k_limit = 2048           # server-side maximum accepted top_k value
collection_id = "search"     # prefix for generated unique collection names
tag = "1970-01-01"           # default partition tag used across tests
insert_interval_time = 1.5   # seconds to wait after inserts where needed
nb = 6000                    # default number of entities inserted per test
top_k = 10                   # default number of nearest neighbours requested
nq = 1                       # default number of query vectors
nprobe = 1                   # default number of probed cells (IVF indexes)
epsilon = 0.001              # tolerance for distance comparisons
# The helpers below come from `utils` (star import) — presumably generators
# for schemas/entities; verify against utils.py.
field_name = default_float_vec_field_name
default_fields = gen_default_fields()
search_param = {"nprobe": 1}
# Entities are generated once at import time and reused by init_data /
# init_binary_data whenever a test asks for the default sizes.
entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, nq)
||
32 | |||
33 | |||
def init_data(connect, collection, nb=6000, partition_tags=None):
    '''
    Generate entities and insert them into the given collection.

    :param connect: client connection to the server
    :param collection: name of the collection to insert into
    :param nb: number of entities to insert; the pre-generated module-level
               ``entities`` are reused when nb == 6000 to avoid regeneration
    :param partition_tags: optional partition tag to insert into
    :return: tuple of (inserted entities, ids assigned by the server)
    '''
    # NOTE: the original declared `global entities` here, but the global is
    # only read (never rebound), so the declaration was unnecessary.
    if nb == 6000:
        insert_entities = entities
    else:
        insert_entities = gen_entities(nb, is_normal=True)
    if partition_tags is None:
        ids = connect.insert(collection, insert_entities)
    else:
        ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
    # Flush so the freshly inserted data is immediately searchable.
    connect.flush([collection])
    return insert_entities, ids
||
49 | |||
50 | |||
def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
    '''
    Generate binary entities and (optionally) insert them into the collection.

    :param connect: client connection to the server
    :param collection: name of the binary collection
    :param nb: number of entities to generate; the pre-generated module-level
               ``binary_entities``/``raw_vectors`` are reused when nb == 6000
    :param insert: when False, only generate data — nothing is inserted and
                   the returned ids list is empty
    :param partition_tags: optional partition tag to insert into
    :return: tuple of (raw int vectors, binary entities, ids from the server)
    '''
    # NOTE: the original declared `global binary_entities` / `global
    # raw_vectors`, but both globals are only read, never rebound.
    ids = []
    if nb == 6000:
        insert_entities = binary_entities
        insert_raw_vectors = raw_vectors
    else:
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
    if insert:  # idiomatic truth test instead of `is True`
        if partition_tags is None:
            ids = connect.insert(collection, insert_entities)
        else:
            ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
        connect.flush([collection])
    return insert_raw_vectors, insert_entities, ids
||
70 | |||
71 | |||
72 | class TestSearchBase: |
||
73 | """ |
||
74 | generate valid create_index params |
||
75 | """ |
||
76 | |||
77 | @pytest.fixture( |
||
78 | scope="function", |
||
79 | params=gen_index() |
||
80 | ) |
||
81 | def get_index(self, request, connect): |
||
82 | if str(connect._cmd("mode")) == "CPU": |
||
83 | if request.param["index_type"] in index_cpu_not_support(): |
||
84 | pytest.skip("sq8h not support in CPU mode") |
||
85 | return request.param |
||
86 | |||
87 | @pytest.fixture( |
||
88 | scope="function", |
||
89 | params=gen_simple_index() |
||
90 | ) |
||
91 | def get_simple_index(self, request, connect): |
||
92 | if str(connect._cmd("mode")) == "CPU": |
||
93 | if request.param["index_type"] in index_cpu_not_support(): |
||
94 | pytest.skip("sq8h not support in CPU mode") |
||
95 | return request.param |
||
96 | |||
97 | @pytest.fixture( |
||
98 | scope="function", |
||
99 | params=gen_simple_index() |
||
100 | ) |
||
101 | def get_jaccard_index(self, request, connect): |
||
102 | logging.getLogger().info(request.param) |
||
103 | if request.param["index_type"] in binary_support(): |
||
104 | return request.param |
||
105 | else: |
||
106 | pytest.skip("Skip index Temporary") |
||
107 | |||
108 | @pytest.fixture( |
||
109 | scope="function", |
||
110 | params=gen_simple_index() |
||
111 | ) |
||
112 | def get_hamming_index(self, request, connect): |
||
113 | logging.getLogger().info(request.param) |
||
114 | if request.param["index_type"] in binary_support(): |
||
115 | return request.param |
||
116 | else: |
||
117 | pytest.skip("Skip index Temporary") |
||
118 | |||
119 | @pytest.fixture( |
||
120 | scope="function", |
||
121 | params=gen_simple_index() |
||
122 | ) |
||
123 | def get_structure_index(self, request, connect): |
||
124 | logging.getLogger().info(request.param) |
||
125 | if request.param["index_type"] == "FLAT": |
||
126 | return request.param |
||
127 | else: |
||
128 | pytest.skip("Skip index Temporary") |
||
129 | |||
130 | """ |
||
131 | generate top-k params |
||
132 | """ |
||
133 | |||
    @pytest.fixture(
        scope="function",
        params=[1, 10, 2049]
    )
    def get_top_k(self, request):
        # Parametrizes top_k; 2049 deliberately exceeds top_k_limit (2048)
        # so tests exercise the server-side rejection path.
        yield request.param
||
140 | |||
    @pytest.fixture(
        scope="function",
        params=[1, 10, 1100]
    )
    def get_nq(self, request):
        # Parametrizes the number of query vectors per search.
        yield request.param
||
147 | |||
    def test_search_flat(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function; all search params correct, vary top-k
        method: search with the given vectors, check the result
        expected: result length equals top_k (exception when top_k exceeds the limit)
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k <= top_k_limit:
            res = connect.search(collection, query)
            assert len(res[0]) == top_k
            # Query vectors are drawn from the inserted data, so the best
            # match must be (near-)exact.
            assert res[0]._distances[0] <= epsilon
            assert check_id_result(res[0], ids[0])
        else:
            # top_k beyond the server limit must be rejected.
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
||
166 | |||
    def test_search_field(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function with an explicit `fields` selection
        method: search with the given vectors, requesting specific fields
        expected: result length equals top_k (exception when top_k exceeds the limit)
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k <= top_k_limit:
            res = connect.search(collection, query, fields=["float_vector"])
            assert len(res[0]) == top_k
            assert res[0]._distances[0] <= epsilon
            assert check_id_result(res[0], ids[0])
            # TODO: assertions for scalar-field retrieval not written yet.
            res = connect.search(collection, query, fields=["float"])
            # TODO
        else:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
||
188 | |||
    @pytest.mark.level(2)
    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function after building each index type
        method: insert, build index, search with the given vectors, check the result
        expected: result length is top_k (exception when top_k exceeds the limit)
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            # NOTE(review): `>=` is looser than the `== top_k` used in
            # test_search_flat; result length cannot exceed top_k, so both pass.
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])
||
215 | |||
    @pytest.mark.level(2)
    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search after index build when a partition exists but data
                was inserted into the default partition
        method: create partition, insert (default partition), build index, search
        expected: result length is top_k; searching the (empty) partition by tag
                  returns nq empty result lists
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])
            # The tagged partition holds no data, so the search still returns
            # nq (empty) result lists.
            res = connect.search(collection, query, partition_tags=[tag])
            assert len(res) == nq
||
245 | |||
    @pytest.mark.level(2)
    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search after index build when data lives in a named partition
        method: insert into the tagged partition, build index, search with the
                tag alone and with an extra (nonexistent) tag
        expected: result length is top_k in both cases
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        # An unknown tag alongside a valid one must not change the result.
        for tags in [[tag], [tag, "new_tag"]]:
            if top_k > top_k_limit:
                with pytest.raises(Exception) as e:
                    res = connect.search(collection, query, partition_tags=tags)
            else:
                res = connect.search(collection, query, partition_tags=tags)
                assert len(res) == nq
                assert len(res[0]) >= top_k
                assert res[0]._distances[0] < epsilon
                assert check_id_result(res[0], ids[0])
||
274 | |||
    @pytest.mark.level(2)
    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
        '''
        target: search with a partition tag that does not exist in the collection
        method: insert into the default partition, search with tag "new_tag"
        expected: empty results (not an error) for a valid top_k; exception only
                  when top_k exceeds the server limit
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query, partition_tags=["new_tag"])
        else:
            # Unknown tag: nq result lists come back, each empty.
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert len(res) == nq
            assert len(res[0]) == 0
||
293 | |||
    @pytest.mark.level(2)
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test search across two populated partitions
        method: insert distinct data into two partitions, build index, search with
                vectors from the first partition, then restrict to the second tag
        expected: exact matches in the source partition; only far matches when
                  restricted to the other partition
        '''
        top_k = get_top_k
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        # nb=6001 forces fresh (different) entities for the second partition.
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            # Query vectors come from the `tag` partition: exact hit there,
            # no hit on ids from the other partition.
            assert check_id_result(res[0], ids[0])
            assert not check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] < epsilon
            assert res[1]._distances[0] < epsilon
            # Restricting to the other partition can only yield far matches.
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert res[0]._distances[0] > epsilon
            assert res[1]._distances[0] > epsilon
||
326 | |||
    # TODO: disabled (leading underscore) — relies on regex partition tags,
    # presumably server-side pattern matching; verify before re-enabling.
    @pytest.mark.level(2)
    def _test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test search with regex-style partition tag patterns
        method: insert into partitions "tag" and "new_tag", search with tag
                patterns "(.*)tag" and "new(.*)"
        expected: matches come only from the partition holding the query vectors
        '''
        top_k = get_top_k
        nq = 2
        tag = "tag"  # shadows the module-level tag on purpose
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        # Query vectors come from the new_tag partition this time.
        query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query, partition_tags=["(.*)tag"])
            assert not check_id_result(res[0], ids[0])
            assert check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] > epsilon
            assert res[1]._distances[0] < epsilon
            res = connect.search(collection, query, partition_tags=["new(.*)"])
            assert res[0]._distances[0] > epsilon
            assert res[1]._distances[0] < epsilon
||
361 | |||
362 | # |
||
363 | # test for ip metric |
||
364 | # |
||
    @pytest.mark.level(2)
    def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search with the IP (inner product) metric, vary top-k
        method: search with the given vectors, check the result
        expected: result length is top_k (exception when top_k exceeds the limit)
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP")
        if top_k <= top_k_limit:
            res = connect.search(collection, query)
            assert len(res[0]) == top_k
            # For IP on normalized data a self-match scores ~1.0.
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert check_id_result(res[0], ids[0])
        else:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
||
384 | |||
    @pytest.mark.level(2)
    def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test IP-metric search after building each index type
        method: insert, build index with metric_type IP, search, check the result
        expected: result length is top_k (exception when top_k exceeds the limit)
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        entities, ids = init_data(connect, collection)
        get_simple_index["metric_type"] = "IP"
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert check_id_result(res[0], ids[0])
            # Self-match under IP scores ~1.0 (within numeric tolerance).
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
||
412 | |||
    @pytest.mark.level(2)
    def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test IP-metric search with an (empty) partition present
        method: create partition, insert into default partition, build IP index,
                search globally and then by the empty partition's tag
        expected: result length is top_k; tag-restricted search returns nq empty lists
        '''
        top_k = get_top_k
        nq = get_nq
        metric_type = "IP"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection)
        get_simple_index["metric_type"] = metric_type
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
                                        search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert check_id_result(res[0], ids[0])
            # The tagged partition holds no data: still nq result lists.
            res = connect.search(collection, query, partition_tags=[tag])
            assert len(res) == nq
||
444 | |||
    @pytest.mark.level(2)
    def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test IP-metric search across two populated partitions
        method: insert distinct data into two partitions, build IP index, search
                with vectors from the first partition, then restrict to the other tag
        expected: high (~1.0) similarity only in the source partition
        '''
        top_k = get_top_k
        nq = 2
        metric_type = "IP"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        get_simple_index["metric_type"] = metric_type
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert check_id_result(res[0], ids[0])
            assert not check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
            # Restricting to the other partition: similarity must drop below ~1.0.
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
            # TODO:
            # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
||
480 | |||
    @pytest.mark.level(2)
    def test_search_without_connect(self, dis_connect, collection):
        '''
        target: test searching without a live connection
        method: call search on a disconnected client instance
        expected: an exception is raised
        '''
        with pytest.raises(Exception) as e:
            res = dis_connect.search(collection, default_query)
||
490 | |||
    def test_search_collection_name_not_existed(self, connect):
        '''
        target: search a collection that does not exist
        method: search with a freshly generated (unknown) collection name
        expected: an exception is raised
        '''
        collection_name = gen_unique_str(collection_id)
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, default_query)
||
500 | |||
    def test_search_distance_l2(self, connect, collection):
        '''
        target: search collection and verify the returned distance value
        method: compare the returned distance with a locally computed Euclidean distance
        expected: the returned distance equals the computed value (within tolerance)
        '''
        nq = 2
        search_param = {"nprobe": 1}
        entities, ids = init_data(connect, collection, nb=nq)
        # `rand_vector=True` makes the queries random (not drawn from the data),
        # so the closest distance is nonzero and worth comparing.
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
        inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        distance_0 = l2(vecs[0], inside_vecs[0])
        distance_1 = l2(vecs[0], inside_vecs[1])
        res = connect.search(collection, query)
        # Server reports squared L2 — take sqrt before comparing.
        assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
||
516 | |||
    # TODO: distance problem — test disabled (leading underscore) until resolved.
    def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
        '''
        target: search collection after index build and verify the distance
        method: compare the returned distance with a locally computed Euclidean minimum
        expected: the returned distance equals the computed value (within tolerance)
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Brute-force the minimum L2 distance over all inserted vectors.
        min_distance = 1.0
        for i in range(nb):
            tmp_dis = l2(vecs[0], inside_vecs[i])
            if min_distance > tmp_dis:
                min_distance = tmp_dis
        res = connect.search(collection, query)
        # Server reports squared L2 — take sqrt before comparing.
        assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])
||
538 | |||
539 | # TODO |
||
540 | View Code Duplication | @pytest.mark.level(2) |
|
541 | def test_search_distance_ip(self, connect, collection): |
||
542 | ''' |
||
543 | target: search collection, and check the result: distance |
||
544 | method: compare the return distance value with value computed with Inner product |
||
545 | expected: the return distance equals to the computed value |
||
546 | ''' |
||
547 | nq = 2 |
||
548 | metirc_type = "IP" |
||
549 | search_param = {"nprobe": 1} |
||
550 | entities, ids = init_data(connect, collection, nb=nq) |
||
551 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type, |
||
552 | search_params=search_param) |
||
553 | inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
554 | distance_0 = ip(vecs[0], inside_vecs[0]) |
||
555 | distance_1 = ip(vecs[0], inside_vecs[1]) |
||
556 | res = connect.search(collection, query) |
||
557 | assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) |
||
558 | |||
    # TODO: distance problem — test disabled (leading underscore) until resolved.
    def _test_search_distance_ip_after_index(self, connect, collection, get_simple_index):
        '''
        target: search collection after building an IP index and verify the distance
        method: compare the returned distance with a locally computed inner-product maximum
        expected: the returned distance equals the computed value (within tolerance)
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        metirc_type = "IP"
        entities, ids = init_data(connect, collection)
        get_simple_index["metric_type"] = metirc_type
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type,
                                        search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Brute-force the maximum inner product over all inserted vectors.
        max_distance = 0
        for i in range(nb):
            tmp_dis = ip(vecs[0], inside_vecs[i])
            if max_distance < tmp_dis:
                max_distance = tmp_dis
        res = connect.search(collection, query)
        assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0])
||
583 | |||
    # TODO: test disabled (leading underscore).
    def _test_search_distance_jaccard_flat_index(self, connect, binary_collection):
        '''
        target: search a binary collection and verify the Jaccard distance
        method: compare the returned distance with a locally computed Jaccard distance
        expected: the returned distance equals the computed value (within epsilon)
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
        res = connect.search(binary_collection, query_entities)
        assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
||
599 | |||
    def _test_search_distance_hamming_flat_index(self, connect, binary_collection):
        '''
        target: search a binary collection and verify the Hamming distance
        method: compare the returned distance with a locally computed Hamming distance
        expected: the returned distance equals the computed value (within epsilon)
        NOTE: disabled (leading underscore).
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        res = connect.search(binary_collection, query_entities)
        # hamming() presumably returns a numpy integer — hence .astype(float).
        assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
||
614 | |||
615 | View Code Duplication | def _test_search_distance_substructure_flat_index(self, connect, binary_collection): |
|
616 | ''' |
||
617 | target: search binary_collection, and check the result: distance |
||
618 | method: compare the return distance value with value computed with Inner product |
||
619 | expected: the return distance equals to the computed value |
||
620 | ''' |
||
621 | # from scipy.spatial import distance |
||
622 | nprobe = 512 |
||
623 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
624 | index_type = "FLAT" |
||
625 | index_param = { |
||
626 | "nlist": 16384, |
||
627 | "metric_type": "SUBSTRUCTURE" |
||
628 | } |
||
629 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
630 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
631 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
632 | query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
633 | distance_0 = substructure(query_int_vectors[0], int_vectors[0]) |
||
634 | distance_1 = substructure(query_int_vectors[0], int_vectors[1]) |
||
635 | search_param = get_search_param(index_type) |
||
636 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
637 | logging.getLogger().info(status) |
||
638 | logging.getLogger().info(result) |
||
639 | assert len(result[0]) == 0 |
||
640 | |||
641 | View Code Duplication | def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection): |
|
642 | ''' |
||
643 | target: search binary_collection, and check the result: distance |
||
644 | method: compare the return distance value with value computed with SUB |
||
645 | expected: the return distance equals to the computed value |
||
646 | ''' |
||
647 | # from scipy.spatial import distance |
||
648 | top_k = 3 |
||
649 | nprobe = 512 |
||
650 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
651 | index_type = "FLAT" |
||
652 | index_param = { |
||
653 | "nlist": 16384, |
||
654 | "metric_type": "SUBSTRUCTURE" |
||
655 | } |
||
656 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
657 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
658 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
659 | query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2) |
||
660 | search_param = get_search_param(index_type) |
||
661 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
662 | logging.getLogger().info(status) |
||
663 | logging.getLogger().info(result) |
||
664 | assert len(result[0]) == 1 |
||
665 | assert len(result[1]) == 1 |
||
666 | assert result[0][0].distance <= epsilon |
||
667 | assert result[0][0].id == ids[0] |
||
668 | assert result[1][0].distance <= epsilon |
||
669 | assert result[1][0].id == ids[1] |
||
670 | |||
671 | View Code Duplication | def _test_search_distance_superstructure_flat_index(self, connect, binary_collection): |
|
672 | ''' |
||
673 | target: search binary_collection, and check the result: distance |
||
674 | method: compare the return distance value with value computed with Inner product |
||
675 | expected: the return distance equals to the computed value |
||
676 | ''' |
||
677 | # from scipy.spatial import distance |
||
678 | nprobe = 512 |
||
679 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
680 | index_type = "FLAT" |
||
681 | index_param = { |
||
682 | "nlist": 16384, |
||
683 | "metric_type": "SUBSTRUCTURE" |
||
684 | } |
||
685 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
686 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
687 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
688 | query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
689 | distance_0 = superstructure(query_int_vectors[0], int_vectors[0]) |
||
690 | distance_1 = superstructure(query_int_vectors[0], int_vectors[1]) |
||
691 | search_param = get_search_param(index_type) |
||
692 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
693 | logging.getLogger().info(status) |
||
694 | logging.getLogger().info(result) |
||
695 | assert len(result[0]) == 0 |
||
696 | |||
697 | View Code Duplication | def _test_search_distance_superstructure_flat_index_B(self, connect, binary_collection): |
|
698 | ''' |
||
699 | target: search binary_collection, and check the result: distance |
||
700 | method: compare the return distance value with value computed with SUPER |
||
701 | expected: the return distance equals to the computed value |
||
702 | ''' |
||
703 | # from scipy.spatial import distance |
||
704 | top_k = 3 |
||
705 | nprobe = 512 |
||
706 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
707 | index_type = "FLAT" |
||
708 | index_param = { |
||
709 | "nlist": 16384, |
||
710 | "metric_type": "SUBSTRUCTURE" |
||
711 | } |
||
712 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
713 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
714 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
715 | query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2) |
||
716 | search_param = get_search_param(index_type) |
||
717 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
718 | logging.getLogger().info(status) |
||
719 | logging.getLogger().info(result) |
||
720 | assert len(result[0]) == 2 |
||
721 | assert len(result[1]) == 2 |
||
722 | assert result[0][0].id in ids |
||
723 | assert result[0][0].distance <= epsilon |
||
724 | assert result[1][0].id in ids |
||
725 | assert result[1][0].distance <= epsilon |
||
726 | |||
727 | View Code Duplication | def _test_search_distance_tanimoto_flat_index(self, connect, binary_collection): |
|
728 | ''' |
||
729 | target: search binary_collection, and check the result: distance |
||
730 | method: compare the return distance value with value computed with Inner product |
||
731 | expected: the return distance equals to the computed value |
||
732 | ''' |
||
733 | # from scipy.spatial import distance |
||
734 | nprobe = 512 |
||
735 | int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2) |
||
736 | index_type = "FLAT" |
||
737 | index_param = { |
||
738 | "nlist": 16384, |
||
739 | "metric_type": "TANIMOTO" |
||
740 | } |
||
741 | connect.create_index(binary_collection, binary_field_name, index_param) |
||
742 | logging.getLogger().info(connect.get_collection_info(binary_collection)) |
||
743 | logging.getLogger().info(connect.get_index_info(binary_collection)) |
||
744 | query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, binary_collection, nb=1, insert=False) |
||
745 | distance_0 = tanimoto(query_int_vectors[0], int_vectors[0]) |
||
746 | distance_1 = tanimoto(query_int_vectors[0], int_vectors[1]) |
||
747 | search_param = get_search_param(index_type) |
||
748 | status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param) |
||
749 | logging.getLogger().info(status) |
||
750 | logging.getLogger().info(result) |
||
751 | assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon |
||
752 | |||
753 | @pytest.mark.timeout(30) |
||
754 | def test_search_concurrent_multithreads(self, connect, args): |
||
755 | ''' |
||
756 | target: test concurrent search with multiprocessess |
||
757 | method: search with 10 processes, each process uses dependent connection |
||
758 | expected: status ok and the returned vectors should be query_records |
||
759 | ''' |
||
760 | nb = 100 |
||
761 | top_k = 10 |
||
762 | threads_num = 4 |
||
763 | threads = [] |
||
764 | collection = gen_unique_str(collection_id) |
||
765 | uri = "tcp://%s:%s" % (args["ip"], args["port"]) |
||
766 | # create collection |
||
767 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
768 | milvus.create_collection(collection, default_fields) |
||
769 | entities, ids = init_data(milvus, collection) |
||
770 | |||
771 | def search(milvus): |
||
772 | res = connect.search(collection, default_query) |
||
773 | assert len(res) == 1 |
||
774 | assert res[0]._entities[0].id in ids |
||
775 | assert res[0]._distances[0] < epsilon |
||
776 | |||
777 | for i in range(threads_num): |
||
778 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
779 | t = threading.Thread(target=search, args=(milvus,)) |
||
780 | threads.append(t) |
||
781 | t.start() |
||
782 | time.sleep(0.2) |
||
783 | for t in threads: |
||
784 | t.join() |
||
785 | |||
786 | @pytest.mark.timeout(30) |
||
787 | def test_search_concurrent_multithreads_single_connection(self, connect, args): |
||
788 | ''' |
||
789 | target: test concurrent search with multiprocessess |
||
790 | method: search with 10 processes, each process uses dependent connection |
||
791 | expected: status ok and the returned vectors should be query_records |
||
792 | ''' |
||
793 | nb = 100 |
||
794 | top_k = 10 |
||
795 | threads_num = 4 |
||
796 | threads = [] |
||
797 | collection = gen_unique_str(collection_id) |
||
798 | uri = "tcp://%s:%s" % (args["ip"], args["port"]) |
||
799 | # create collection |
||
800 | milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) |
||
801 | milvus.create_collection(collection, default_fields) |
||
802 | entities, ids = init_data(milvus, collection) |
||
803 | |||
804 | def search(milvus): |
||
805 | res = connect.search(collection, default_query) |
||
806 | assert len(res) == 1 |
||
807 | assert res[0]._entities[0].id in ids |
||
808 | assert res[0]._distances[0] < epsilon |
||
809 | |||
810 | for i in range(threads_num): |
||
811 | t = threading.Thread(target=search, args=(milvus,)) |
||
812 | threads.append(t) |
||
813 | t.start() |
||
814 | time.sleep(0.2) |
||
815 | for t in threads: |
||
816 | t.join() |
||
817 | |||
818 | def test_search_multi_collections(self, connect, args): |
||
819 | ''' |
||
820 | target: test search multi collections of L2 |
||
821 | method: add vectors into 10 collections, and search |
||
822 | expected: search status ok, the length of result |
||
823 | ''' |
||
824 | num = 10 |
||
825 | top_k = 10 |
||
826 | nq = 20 |
||
827 | for i in range(num): |
||
828 | collection = gen_unique_str(collection_id + str(i)) |
||
829 | connect.create_collection(collection, default_fields) |
||
830 | entities, ids = init_data(connect, collection) |
||
831 | assert len(ids) == nb |
||
832 | query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) |
||
833 | res = connect.search(collection, query) |
||
834 | assert len(res) == nq |
||
835 | for i in range(nq): |
||
836 | assert check_id_result(res[i], ids[i]) |
||
837 | assert res[i]._distances[0] < epsilon |
||
838 | assert res[i]._distances[1] > epsilon |
||
839 | |||
840 | |||
class TestSearchDSL(object):
    """
    ******************************************************************
    # The following cases are used to build invalid query expr
    ******************************************************************
    """

    def test_query_no_must(self, connect, collection):
        '''
        method: build query without must expr
        expected: error raised
        '''
        query = update_query_expr(default_query, keep_old=False)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_no_vector_term_only(self, connect, collection):
        '''
        method: build query with only a term expr and no vector expr
        expected: error raised
        '''
        # Fixed: the original put the generator function itself (uncalled)
        # into the expr, so the search failed because of a malformed
        # placeholder rather than the missing vector clause being tested.
        expr = {
            "must": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_vector_only(self, connect, collection):
        '''
        method: search with the default vector-only query
        expected: nq result lists, top_k hits each
        '''
        entities, ids = init_data(connect, collection)
        res = connect.search(collection, default_query)
        assert len(res) == nq
        assert len(res[0]) == top_k

    def test_query_wrong_format(self, connect, collection):
        '''
        method: build query without must expr, with wrong expr name
        expected: error raised
        '''
        # Fixed: call gen_default_term_expr() so the only invalid part of
        # the query is the misspelled "must1" key under test.
        expr = {
            "must1": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_empty(self, connect, collection):
        '''
        method: search with empty query
        expected: error raised
        '''
        query = {}
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    ******************************************************************
    # The following cases are used to build valid query expr
    ******************************************************************
    """

    @pytest.mark.level(2)
    def test_query_term_value_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0
        # TODO:

    @pytest.mark.level(2)
    def test_query_term_value_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 1
        # TODO:

    @pytest.mark.level(2)
    def test_query_term_values_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(values=[i for i in range(100000, 100010)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0
        # TODO:

    def test_query_term_values_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
        # TODO:

    def test_query_term_values_parts_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with parts of term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
        # TODO:

    @pytest.mark.level(2)
    def test_query_term_values_repeat(self, connect, collection):
        '''
        method: build query with vector and term expr, with the same values
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, nb)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 1
        # TODO:

    def test_query_term_value_empty(self, connect, collection):
        '''
        method: build query with term value empty
        expected: return null
        '''
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    """
    ******************************************************************
    # The following cases are used to build invalid term query expr
    ******************************************************************
    """

    # TODO
    @pytest.mark.level(2)
    def test_query_term_key_error(self, connect, collection):
        '''
        method: build query with term key error
        expected: Exception raised
        '''
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(keyword="terrm", values=[i for i in range(nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_term()
    )
    def get_invalid_term(self, request):
        return request.param

    # TODO
    @pytest.mark.level(2)
    def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
        '''
        method: build query with wrong format term
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        term = get_invalid_term
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # TODO
    @pytest.mark.level(2)
    def test_query_term_field_named_term(self, connect, collection):
        '''
        method: build query with a scalar field literally named "term"
        expected: search still works (no clash with the "term" keyword)
        '''
        term_fields = add_field_default(default_fields, field_name="term")
        collection_term = gen_unique_str("term")
        connect.create_collection(collection_term, term_fields)
        term_entities = add_field(entities, field_name="term")
        ids = connect.insert(collection_term, term_entities)
        assert len(ids) == nb
        connect.flush([collection_term])
        count = connect.count_entities(collection_term)
        assert count == nb
        term_param = {"term": {"term": {"values": [i for i in range(nb // 2)]}}}
        expr = {"must": [gen_default_vector_expr(default_query),
                         term_param]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection_term, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
        connect.drop_collection(collection_term)

    """
    ******************************************************************
    # The following cases are used to build valid range query expr
    ******************************************************************
    """

    # TODO
    def test_query_range_key_error(self, connect, collection):
        '''
        method: build query with range key error
        expected: Exception raised
        '''
        # renamed from `range` so the builtin is not shadowed
        range_expr = gen_default_range_expr(keyword="ranges")
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_range()
    )
    def get_invalid_range(self, request):
        return request.param

    # TODO
    @pytest.mark.level(2)
    def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
        '''
        method: build query with wrong format range
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        range_expr = get_invalid_range
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_valid_ranges()
    )
    def get_valid_ranges(self, request):
        return request.param

    # TODO:
    def _test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
        '''
        method: build query with valid ranges
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        ranges = get_valid_ranges
        range_expr = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
1127 | |||
1128 | |||
class TestSearchDSLBools(object):
    """
    ******************************************************************
    # The following cases are used to build invalid query expr
    ******************************************************************
    """

    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        expr = {"bool1": {}}
        # Fixed: search the malformed expr directly. The original referenced
        # an undefined name `query` here, so the test passed only because of
        # the resulting NameError rather than server-side validation.
        with pytest.raises(Exception) as e:
            res = connect.search(collection, expr)

    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        # NOTE(review): gen_default_term_expr is passed uncalled; the query
        # is invalid either way, but verify this matches the intended case.
        expr = {"should": gen_default_term_expr}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        expr = {"must_not": gen_default_term_expr}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_should(self, connect, collection):
        '''
        method: build query with must, and with should.term
        expected: error raised
        '''
        expr = {"should": gen_default_term_expr}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
1194 | |||
1195 | |||
1196 | """ |
||
1197 | ****************************************************************** |
||
1198 | # The following cases are used to test `search` function |
||
1199 | # with invalid collection_name, or invalid query expr |
||
1200 | ****************************************************************** |
||
1201 | """ |
||
1202 | |||
1203 | |||
class TestSearchInvalid(object):
    """
    Test search collection with invalid collection names
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # sq8h-style indexes only run on GPU builds
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        collection_name = get_collection_name
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, default_query)

    @pytest.mark.level(1)
    def test_search_with_invalid_tag(self, connect, collection):
        tag = " "
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        fields = [get_invalid_field]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        # Fixed: work on a deep copy. The original mutated the shared
        # module-level default_query in place, leaking the invalid topk
        # into every test that runs afterwards.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    Test search collection with invalid search params
    """

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    # TODO: This case can all pass, but it's too slow
    @pytest.mark.level(2)
    def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        # Skip before inserting nb entities and building the index; the
        # original only skipped after the expensive setup had already run.
        if search_params["index_type"] != index_type:
            pytest.skip("Skip case")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.mark.level(2)
    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params={})
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
1333 | |||
1334 | |||
def check_id_result(result, id):
    """Return True if `id` appears among the first 5 hits of `result`
    (or among all hits when fewer than 5 are returned)."""
    limit_in = 5
    ids = [entity.id for entity in result]
    # Slicing a list shorter than limit_in yields the whole list, so one
    # slice covers both the >= 5 and the < 5 branches of the original.
    return id in ids[:limit_in]
1342 |