Total Complexity | 149 |
Total Lines | 1174 |
Duplicated Lines | 25.81 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.
Common duplication problems, and corresponding solutions are:
Complex classes like test_search often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | import time |
||
2 | import pdb |
||
3 | import copy |
||
4 | import threading |
||
5 | import logging |
||
6 | from multiprocessing import Pool, Process |
||
7 | import pytest |
||
8 | import numpy as np |
||
9 | |||
10 | from milvus import DataType |
||
11 | from utils import * |
||
12 | |||
# ---- scalar settings shared by the search tests ----
dim = 128
segment_row_count = 5000
top_k_limit = 2048  # tests expect search to raise once top_k exceeds this
collection_id = "search"  # prefix handed to gen_unique_str for collection names
tag = "1970-01-01"  # partition tag used by the partition tests
insert_interval_time = 1.5
nb = 6000  # size of the pre-generated entity batches below
top_k = 10
nprobe = 1
epsilon = 0.001  # tolerance used by the distance assertions
field_name = default_float_vec_field_name

# ---- data generated once at import time and reused across tests ----
default_fields = gen_default_fields()
search_param = {"nprobe": nprobe}
entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, 1)
||
31 | |||
def init_data(connect, collection, nb=6000, partition_tags=None):
    '''
    Insert float-vector entities into `collection` and flush it.

    When nb is 6000 the module-level pre-generated `entities` are reused;
    for any other nb a new batch is built with gen_entities(nb, is_normal=True).
    When partition_tags is not None it is forwarded to insert as `partition_tag`.

    Returns a tuple: (entities that were inserted, ids returned by connect.insert).
    '''
    batch = entities if nb == 6000 else gen_entities(nb, is_normal=True)
    insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
    inserted_ids = connect.insert(collection, batch, **insert_kwargs)
    connect.flush([collection])
    return batch, inserted_ids
||
47 | |||
def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
    '''
    Pick (or generate) binary entities and optionally insert them into `collection`.

    When nb is 6000 the module-level `raw_vectors` / `binary_entities` are reused;
    otherwise gen_binary_entities(nb) builds a new batch. Insertion and flush only
    happen when `insert` is exactly True; partition_tags, when not None, is
    forwarded to insert as `partition_tag`.

    Returns a tuple: (raw vectors, entities, ids). ids is an empty list when
    nothing was inserted.
    '''
    if nb == 6000:
        source_vectors, batch = raw_vectors, binary_entities
    else:
        source_vectors, batch = gen_binary_entities(nb)

    inserted_ids = []
    if insert is True:
        insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
        inserted_ids = connect.insert(collection, batch, **insert_kwargs)
        connect.flush([collection])
    return source_vectors, batch, inserted_ids
||
67 | |||
68 | |||
69 | class TestSearchBase: |
||
70 | |||
71 | |||
72 | """ |
||
73 | generate valid create_index params |
||
74 | """ |
||
@pytest.fixture(scope="function", params=gen_index())
def get_index(self, request, connect):
    '''
    Yield each index param from gen_index(); skip params whose index_type is
    listed by index_cpu_not_support() when the server reports CPU mode.
    '''
    index_param = request.param
    running_on_cpu = str(connect._cmd("mode")) == "CPU"
    if running_on_cpu and index_param["index_type"] in index_cpu_not_support():
        pytest.skip("sq8h not support in CPU mode")
    return index_param
||
84 | |||
@pytest.fixture(scope="function", params=gen_simple_index())
def get_simple_index(self, request, connect):
    '''
    Yield each index param from gen_simple_index(); skip params whose index_type
    is listed by index_cpu_not_support() when the server reports CPU mode.
    '''
    index_param = request.param
    running_on_cpu = str(connect._cmd("mode")) == "CPU"
    if running_on_cpu and index_param["index_type"] in index_cpu_not_support():
        pytest.skip("sq8h not support in CPU mode")
    return index_param
||
94 | |||
@pytest.fixture(scope="function", params=gen_simple_index())
def get_jaccard_index(self, request, connect):
    '''
    Log the candidate index param and return it only when its index_type is
    listed by binary_support(); every other param is skipped.
    '''
    index_param = request.param
    logging.getLogger().info(index_param)
    if index_param["index_type"] not in binary_support():
        pytest.skip("Skip index Temporary")
    return index_param
||
105 | |||
@pytest.fixture(scope="function", params=gen_simple_index())
def get_hamming_index(self, request, connect):
    '''
    Log the candidate index param and return it only when its index_type is
    listed by binary_support(); every other param is skipped.
    '''
    index_param = request.param
    logging.getLogger().info(index_param)
    if index_param["index_type"] not in binary_support():
        pytest.skip("Skip index Temporary")
    return index_param
||
116 | |||
@pytest.fixture(scope="function", params=gen_simple_index())
def get_structure_index(self, request, connect):
    '''
    Log the candidate index param and return it only when its index_type is
    "FLAT"; every other param is skipped.
    '''
    index_param = request.param
    logging.getLogger().info(index_param)
    if index_param["index_type"] != "FLAT":
        pytest.skip("Skip index Temporary")
    return index_param
||
127 | |||
128 | """ |
||
129 | generate top-k params |
||
130 | """ |
||
@pytest.fixture(scope="function", params=[1, 10, 2049])
def get_top_k(self, request):
    '''
    Parametrized top-k values: 1, 10 and 2049.
    '''
    chosen_top_k = request.param
    yield chosen_top_k
||
137 | |||
@pytest.fixture(scope="function", params=[1, 10, 1100])
def get_nq(self, request):
    '''
    Parametrized number of query vectors: 1, 10 and 1100.
    '''
    chosen_nq = request.param
    yield chosen_nq
||
144 | |||
def test_search_flat(self, connect, collection, get_top_k, get_nq):
    '''
    target: test basic search function, all the search params are correct, change top-k value
    method: search with the given vectors, check the result
    expected: the length of the result is top_k; search raises when top_k exceeds top_k_limit
    '''
    top_k, nq = get_top_k, get_nq
    entities, ids = init_data(connect, collection)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
    if top_k > top_k_limit:
        # an over-limit top_k is expected to make search raise
        with pytest.raises(Exception):
            connect.search(collection, query)
        return
    res = connect.search(collection, query)
    first_hits = res[0]
    assert len(first_hits) == top_k
    assert first_hits._distances[0] <= epsilon
    assert check_id_result(first_hits, ids[0])
||
163 | |||
def test_search_field(self, connect, collection, get_top_k, get_nq):
    '''
    target: test basic search function with the `fields` argument, all the search params are correct, change top-k value
    method: search with the given vectors, check the result
    expected: the length of the result is top_k; search raises when top_k exceeds top_k_limit
    '''
    top_k, nq = get_top_k, get_nq
    entities, ids = init_data(connect, collection)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
    if top_k > top_k_limit:
        # the over-limit branch searches without `fields`, as before
        with pytest.raises(Exception):
            connect.search(collection, query)
        return
    res = connect.search(collection, query, fields=["float_vector"])
    first_hits = res[0]
    assert len(first_hits) == top_k
    assert first_hits._distances[0] <= epsilon
    assert check_id_result(first_hits, ids[0])
    # TODO
    res = connect.search(collection, query, fields=["float"])
    # TODO
||
185 | |||
@pytest.mark.level(2)
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
    '''
    target: test basic search function, all the search params are correct, test all index params, and build
    method: search with the given vectors, check the result
    expected: at least top_k hits per query; search raises when top_k exceeds top_k_limit
    '''
    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    top_k, nq = get_top_k, get_nq
    entities, ids = init_data(connect, collection)
    connect.create_index(collection, field_name, get_simple_index)
    query, vecs = gen_query_vectors(
        field_name, entities, top_k, nq,
        search_params=get_search_param(index_type))
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
        return
    res = connect.search(collection, query)
    assert len(res) == nq
    first_hits = res[0]
    assert len(first_hits) >= top_k
    assert first_hits._distances[0] < epsilon
    assert check_id_result(first_hits, ids[0])
||
212 | |||
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
    '''
    target: test basic search function, all the search params are correct, test all index params, and build
    method: create a partition, add vectors into the collection (not into the partition),
            search with the given vectors, check the result
    expected: at least top_k hits per query on the collection; searching with the
              partition tag returns one result set per query
    '''
    top_k = get_top_k
    nq = get_nq

    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    connect.create_partition(collection, tag)
    entities, ids = init_data(connect, collection)
    connect.create_index(collection, field_name, get_simple_index)
    search_param = get_search_param(index_type)
    # Fixed: the helper is gen_query_vectors (as in every sibling test); the
    # previous trailing-underscore name did not match any helper used in this file.
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
    else:
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert res[0]._distances[0] < epsilon
        assert check_id_result(res[0], ids[0])
        res = connect.search(collection, query, partition_tags=[tag])
        assert len(res) == nq
||
241 | |||
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
    '''
    target: test basic search function, all the search params are correct, test all index params, and build
    method: insert into a partition, then search with [tag] and with [tag, "new_tag"], check the result
    expected: at least top_k hits per query; search raises when top_k exceeds top_k_limit
    '''
    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    top_k, nq = get_top_k, get_nq
    connect.create_partition(collection, tag)
    entities, ids = init_data(connect, collection, partition_tags=tag)
    connect.create_index(collection, field_name, get_simple_index)
    query, vecs = gen_query_vectors(
        field_name, entities, top_k, nq,
        search_params=get_search_param(index_type))
    for tags in ([tag], [tag, "new_tag"]):
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query, partition_tags=tags)
            continue
        res = connect.search(collection, query, partition_tags=tags)
        assert len(res) == nq
        first_hits = res[0]
        assert len(first_hits) >= top_k
        assert first_hits._distances[0] < epsilon
        assert check_id_result(first_hits, ids[0])
||
269 | |||
@pytest.mark.level(2)
def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
    '''
    target: test basic search function with a partition tag that does not exist in the collection
    method: search with the given vectors and tag "new_tag", check the result
    expected: one empty result set per query; search raises when top_k exceeds top_k_limit
    '''
    top_k, nq = get_top_k, get_nq
    missing_tags = ["new_tag"]
    entities, ids = init_data(connect, collection)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query, partition_tags=missing_tags)
        return
    res = connect.search(collection, query, partition_tags=missing_tags)
    assert len(res) == nq
    assert len(res[0]) == 0
||
288 | |||
@pytest.mark.level(2)
def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
    '''
    target: test basic search function, all the search params are correct, test all index params, and build
    method: insert into two partitions, search the collection with vectors taken from the first
            partition, then search again restricted to the second partition
    expected: collection-wide search finds the query vectors; the search restricted to
              "new_tag" does not return near-zero distances
    '''
    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    top_k = get_top_k
    nq = 2
    new_tag = "new_tag"
    connect.create_partition(collection, tag)
    connect.create_partition(collection, new_tag)
    entities, ids = init_data(connect, collection, partition_tags=tag)
    # nb=6001 makes init_data generate a fresh batch instead of reusing the shared one
    new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
    connect.create_index(collection, field_name, get_simple_index)
    query, vecs = gen_query_vectors(
        field_name, entities, top_k, nq,
        search_params=get_search_param(index_type))
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
        return
    res = connect.search(collection, query)
    assert check_id_result(res[0], ids[0])
    assert not check_id_result(res[1], new_ids[0])
    for hits in (res[0], res[1]):
        assert hits._distances[0] < epsilon
    res = connect.search(collection, query, partition_tags=[new_tag])
    for hits in (res[0], res[1]):
        assert hits._distances[0] > epsilon
||
321 | |||
322 | # TODO: |
||
@pytest.mark.level(2)
def _test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
    '''
    target: test basic search function, all the search params are correct, test all index params, and build
    method: insert into partitions "tag" and "new_tag", search with vectors taken from the
            second batch using the tag patterns "(.*)tag" and "new(.*)"
    expected: only the second query vector is matched with a near-zero distance
    '''
    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    top_k = get_top_k
    nq = 2
    tag = "tag"  # local tag, shadows the module-level default on purpose
    new_tag = "new_tag"
    connect.create_partition(collection, tag)
    connect.create_partition(collection, new_tag)
    entities, ids = init_data(connect, collection, partition_tags=tag)
    new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
    connect.create_index(collection, field_name, get_simple_index)
    query, vecs = gen_query_vectors(
        field_name, new_entities, top_k, nq,
        search_params=get_search_param(index_type))
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
        return
    res = connect.search(collection, query, partition_tags=["(.*)tag"])
    assert not check_id_result(res[0], ids[0])
    assert check_id_result(res[1], new_ids[0])
    assert res[0]._distances[0] > epsilon
    assert res[1]._distances[0] < epsilon
    res = connect.search(collection, query, partition_tags=["new(.*)"])
    assert res[0]._distances[0] > epsilon
    assert res[1]._distances[0] < epsilon
||
356 | |||
357 | # |
||
358 | # test for ip metric |
||
359 | # |
||
@pytest.mark.level(2)
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
    '''
    target: test basic search function with the IP metric, all the search params are correct, change top-k value
    method: search with the given vectors, check the result
    expected: the length of the result is top_k; search raises when top_k exceeds top_k_limit
    '''
    # NOTE: get_simple_index is requested by the signature but not used in the body.
    top_k, nq = get_top_k, get_nq
    entities, ids = init_data(connect, collection)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP")
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
        return
    res = connect.search(collection, query)
    first_hits = res[0]
    assert len(first_hits) == top_k
    best = first_hits._distances[0]
    assert best >= 1 - gen_inaccuracy(best)
    assert check_id_result(first_hits, ids[0])
||
379 | |||
def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
    '''
    target: test basic search function with the IP metric, all the search params are correct, test all index params, and build
    method: build an IP index, search with the given vectors, check the result
    expected: at least top_k hits per query; search raises when top_k exceeds top_k_limit
    '''
    top_k = get_top_k
    nq = get_nq

    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    entities, ids = init_data(connect, collection)
    # Build the IP variant on a copy: the fixture hands out the parametrized
    # dict itself, so writing metric_type into it would leak "IP" into every
    # later test that receives the same param object.
    index_param = dict(get_simple_index, metric_type="IP")
    connect.create_index(collection, field_name, index_param)
    search_param = get_search_param(index_type)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
    else:
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert check_id_result(res[0], ids[0])
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
||
406 | |||
@pytest.mark.level(2)
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
    '''
    target: test basic search function with the IP metric, all the search params are correct, test all index params, and build
    method: create a partition, add vectors into the collection (not into the partition),
            build an IP index, search with the given vectors, check the result
    expected: at least top_k hits per query on the collection; searching with the
              partition tag returns one result set per query
    '''
    top_k = get_top_k
    nq = get_nq
    metric_type = "IP"
    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    connect.create_partition(collection, tag)
    entities, ids = init_data(connect, collection)
    # Copy instead of mutating: the parametrized fixture dict is shared with
    # other tests, so setting metric_type on it in place would leak into them.
    index_param = dict(get_simple_index, metric_type=metric_type)
    connect.create_index(collection, field_name, index_param)
    search_param = get_search_param(index_type)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, search_params=search_param)
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
    else:
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
        assert check_id_result(res[0], ids[0])
        res = connect.search(collection, query, partition_tags=[tag])
        assert len(res) == nq
||
437 | |||
@pytest.mark.level(2)
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
    '''
    target: test basic search function with the IP metric, all the search params are correct, test all index params, and build
    method: insert into two partitions, build an IP index, search the collection with vectors
            taken from the first partition, then search again restricted to the second partition
    expected: collection-wide search finds the query vectors; the search restricted to
              "new_tag" does not return a near-1 inner product for the first query
    '''
    top_k = get_top_k
    nq = 2
    metric_type = "IP"
    new_tag = "new_tag"
    index_type = get_simple_index["index_type"]
    if index_type == "IVF_PQ":
        pytest.skip("Skip PQ")
    connect.create_partition(collection, tag)
    connect.create_partition(collection, new_tag)
    entities, ids = init_data(connect, collection, partition_tags=tag)
    # nb=6001 makes init_data generate a fresh batch instead of reusing the shared one
    new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
    # Copy instead of mutating: the parametrized fixture dict is shared with
    # other tests, so setting metric_type on it in place would leak into them.
    index_param = dict(get_simple_index, metric_type=metric_type)
    connect.create_index(collection, field_name, index_param)
    search_param = get_search_param(index_type)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, search_params=search_param)
    if top_k > top_k_limit:
        with pytest.raises(Exception):
            connect.search(collection, query)
    else:
        res = connect.search(collection, query)
        assert check_id_result(res[0], ids[0])
        assert not check_id_result(res[1], new_ids[0])
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
        assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
        res = connect.search(collection, query, partition_tags=[new_tag])
        assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
        # TODO:
        # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
||
473 | |||
@pytest.mark.level(2)
def test_search_without_connect(self, dis_connect, collection):
    '''
    target: test search vectors without connection
    method: call search on a disconnected instance
    expected: an exception is raised
    '''
    with pytest.raises(Exception):
        dis_connect.search(collection, default_query)
||
483 | |||
def test_search_collection_name_not_existed(self, connect):
    '''
    target: search a collection that does not exist
    method: search with a freshly generated collection name that was never created
    expected: an exception is raised
    '''
    missing_collection = gen_unique_str(collection_id)
    with pytest.raises(Exception):
        connect.search(missing_collection, default_query)
||
493 | |||
def test_search_distance_l2(self, connect, collection):
    '''
    target: search collection, and check the result: distance
    method: compare the returned distance value with the value computed with Euclidean
    expected: the returned distance equals the computed value
    '''
    nq = 2
    params = {"nprobe" : 1}
    entities, ids = init_data(connect, collection, nb=nq)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=params)
    inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=params)
    probe = vecs[0]
    expected = min(l2(probe, inside_vecs[0]), l2(probe, inside_vecs[1]))
    res = connect.search(collection, query)
    best = res[0]._distances[0]
    # the returned value is square-rooted before it is compared with l2()
    assert abs(np.sqrt(best) - expected) <= gen_inaccuracy(best)
||
509 | |||
510 | # TODO: distance problem |
||
def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
    '''
    target: search collection, and check the result: distance
    method: build an index, then compare the returned distance value with the
            minimum Euclidean distance computed against every inserted vector
    expected: the returned distance equals the computed value
    '''
    index_type = get_simple_index["index_type"]
    nq = 2
    entities, ids = init_data(connect, collection)
    connect.create_index(collection, field_name, get_simple_index)
    search_param = get_search_param(index_type)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
    inside_vecs = entities[-1]["values"]
    # Fixed: take the true minimum. The previous running minimum started at 1.0,
    # so it stayed at 1.0 whenever every real distance was larger than that.
    min_distance = min(l2(vecs[0], inside_vec) for inside_vec in inside_vecs)
    res = connect.search(collection, query)
    assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])
||
531 | |||
def test_search_distance_ip(self, connect, collection):
    '''
    target: search collection, and check the result: distance
    method: compare the returned distance value with the value computed with Inner product
    expected: the returned distance equals the computed value
    '''
    nq = 2
    metric = "IP"
    params = {"nprobe" : 1}
    entities, ids = init_data(connect, collection, nb=nq)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metric, search_params=params)
    inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=params)
    probe = vecs[0]
    expected = max(ip(probe, inside_vecs[0]), ip(probe, inside_vecs[1]))
    res = connect.search(collection, query)
    best = res[0]._distances[0]
    assert abs(best - expected) <= gen_inaccuracy(best)
||
548 | |||
549 | # TODO: distance problem |
||
def _test_search_distance_ip_after_index(self, connect, collection, get_simple_index):
    '''
    target: search collection, and check the result: distance
    method: build an IP index, then compare the returned distance value with the
            maximum inner product computed against every inserted vector
    expected: the returned distance equals the computed value
    '''
    index_type = get_simple_index["index_type"]
    nq = 2
    metric_type = "IP"
    entities, ids = init_data(connect, collection)
    # Copy instead of mutating: the parametrized fixture dict is shared with
    # other tests, so setting metric_type on it in place would leak into them.
    index_param = dict(get_simple_index, metric_type=metric_type)
    connect.create_index(collection, field_name, index_param)
    search_param = get_search_param(index_type)
    query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metric_type, search_params=search_param)
    inside_vecs = entities[-1]["values"]
    # Fixed: take the true maximum. The previous running maximum started at 0,
    # which is wrong whenever every inner product is negative.
    max_distance = max(ip(vecs[0], inside_vec) for inside_vec in inside_vecs)
    res = connect.search(collection, query)
    assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0])
||
572 | |||
573 | # TODO: |
||
def _test_search_distance_jaccard_flat_index(self, connect, binary_collection):
    '''
    target: search binary_collection, and check the result: distance
    method: compare the returned distance value with the value computed by jaccard()
    expected: the returned distance equals the computed value
    '''
    # from scipy.spatial import distance
    int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
    # insert=False: only generate a query vector, do not add it to the collection
    query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
    probe = query_int_vectors[0]
    expected = min(jaccard(probe, int_vectors[0]), jaccard(probe, int_vectors[1]))
    res = connect.search(binary_collection, query_entities)
    assert abs(res[0]._distances[0] - expected) <= epsilon
||
588 | |||
def _test_search_distance_hamming_flat_index(self, connect, binary_collection):
    '''
    target: search binary_collection, and check the result: distance
    method: compare the returned distance value with the value computed by hamming()
    expected: the returned distance equals the computed value
    '''
    # from scipy.spatial import distance
    int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
    # insert=False: only generate a query vector, do not add it to the collection
    query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
    probe = query_int_vectors[0]
    expected = min(hamming(probe, int_vectors[0]), hamming(probe, int_vectors[1]))
    res = connect.search(binary_collection, query_entities)
    assert abs(res[0][0].distance - expected.astype(float)) <= epsilon
||
603 | |||
def _test_search_distance_substructure_flat_index(self, connect, binary_collection):
    '''
    target: search binary_collection with the SUBSTRUCTURE metric, and check the result
    method: insert two binary vectors, build a FLAT index with metric_type SUBSTRUCTURE,
            search with a separately generated query vector
    expected: the first result set is empty
    '''
    # from scipy.spatial import distance
    # Fixed: init_binary_data is a module-level helper, not a method of this class,
    # so it is called directly (as the jaccard/hamming tests do).
    int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
    index_type = "FLAT"
    index_param = {
        "nlist": 16384,
        "metric_type": "SUBSTRUCTURE"
    }
    # NOTE(review): binary_field_name is not defined in the visible part of this
    # file; presumably it comes from utils - confirm before enabling this test.
    connect.create_index(binary_collection, binary_field_name, index_param)
    logging.getLogger().info(connect.get_collection_info(binary_collection))
    logging.getLogger().info(connect.get_index_info(binary_collection))
    # insert=False: only generate a query vector, do not add it to the collection
    query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
    # computed but not asserted on
    distance_0 = substructure(query_int_vectors[0], int_vectors[0])
    distance_1 = substructure(query_int_vectors[0], int_vectors[1])
    search_param = get_search_param(index_type)
    # NOTE(review): this call shape (top_k positional, status/result tuple) differs
    # from the connect.search(collection, query) form used by the enabled tests.
    status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
    logging.getLogger().info(status)
    logging.getLogger().info(result)
    assert len(result[0]) == 0
||
629 | |||
def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
    '''
    target: search binary_collection with the SUBSTRUCTURE metric, and check the result
    method: insert two binary vectors, build a FLAT index with metric_type SUBSTRUCTURE,
            search with sub-vectors produced by gen_binary_sub_vectors from the inserted vectors
    expected: each query returns exactly one hit, with distance <= epsilon and the id of
              the vector it was derived from
    '''
    # from scipy.spatial import distance
    top_k = 3
    # Fixed: init_binary_data is a module-level helper, not a method of this class,
    # so it is called directly (as the jaccard/hamming tests do).
    int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
    index_type = "FLAT"
    index_param = {
        "nlist": 16384,
        "metric_type": "SUBSTRUCTURE"
    }
    # NOTE(review): binary_field_name is not defined in the visible part of this
    # file; presumably it comes from utils - confirm before enabling this test.
    connect.create_index(binary_collection, binary_field_name, index_param)
    logging.getLogger().info(connect.get_collection_info(binary_collection))
    logging.getLogger().info(connect.get_index_info(binary_collection))
    query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
    search_param = get_search_param(index_type)
    # NOTE(review): this call shape (top_k positional, status/result tuple) differs
    # from the connect.search(collection, query) form used by the enabled tests.
    status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
    logging.getLogger().info(status)
    logging.getLogger().info(result)
    assert len(result[0]) == 1
    assert len(result[1]) == 1
    assert result[0][0].distance <= epsilon
    assert result[0][0].id == ids[0]
    assert result[1][0].distance <= epsilon
    assert result[1][0].id == ids[1]
||
659 | |||
def _test_search_distance_superstructure_flat_index(self, connect, binary_collection):
    '''
    target: search binary_collection, and check the result: distance
    method: build a FLAT index with the SUPERSTRUCTURE metric and query with a
            freshly generated binary vector that was not inserted
    expected: the query returns no hit
    '''
    int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2)
    index_type = "FLAT"
    index_param = {
        "nlist": 16384,
        # Was "SUBSTRUCTURE" (copy-paste from the substructure test); this test
        # exercises the superstructure metric.
        "metric_type": "SUPERSTRUCTURE"
    }
    connect.create_index(binary_collection, binary_field_name, index_param)
    logging.getLogger().info(connect.get_collection_info(binary_collection))
    logging.getLogger().info(connect.get_index_info(binary_collection))
    query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, binary_collection, nb=1, insert=False)
    # Locally computed reference values; kept for parity with the sibling tests
    # although nothing is asserted against them (an empty result is expected).
    distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
    distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
    search_param = get_search_param(index_type)
    # NOTE: legacy (status, result) search signature; test is disabled via the
    # leading underscore in its name.
    status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
    logging.getLogger().info(status)
    logging.getLogger().info(result)
    assert len(result[0]) == 0
||
685 | |||
def _test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
    '''
    target: search binary_collection, and check the result: distance
    method: build a FLAT index with the SUPERSTRUCTURE metric and query with the
            vectors returned by gen_binary_super_vectors for the inserted vectors
    expected: each query returns two hits; the first hit is one of the inserted
              ids and its distance is within epsilon
    '''
    top_k = 3
    int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2)
    index_type = "FLAT"
    index_param = {
        "nlist": 16384,
        # Was "SUBSTRUCTURE" (copy-paste from the substructure test); this test
        # exercises the superstructure metric.
        "metric_type": "SUPERSTRUCTURE"
    }
    connect.create_index(binary_collection, binary_field_name, index_param)
    logging.getLogger().info(connect.get_collection_info(binary_collection))
    logging.getLogger().info(connect.get_index_info(binary_collection))
    query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
    search_param = get_search_param(index_type)
    # NOTE: legacy (status, result) search signature; test is disabled via the
    # leading underscore in its name.
    status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
    logging.getLogger().info(status)
    logging.getLogger().info(result)
    assert len(result[0]) == 2
    assert len(result[1]) == 2
    assert result[0][0].id in ids
    assert result[0][0].distance <= epsilon
    assert result[1][0].id in ids
    assert result[1][0].distance <= epsilon
||
715 | |||
def _test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
    '''
    target: search binary_collection, and check the result: distance
    method: build a FLAT index with the TANIMOTO metric, query with a freshly
            generated binary vector and compare against tanimoto() computed locally
    expected: the top-1 distance equals the smaller of the two locally computed
              distances (within epsilon)
    '''
    index_type = "FLAT"
    log = logging.getLogger()
    int_vectors, vectors, ids = self.init_binary_data(connect, binary_collection, nb=2)
    connect.create_index(
        binary_collection,
        binary_field_name,
        {"nlist": 16384, "metric_type": "TANIMOTO"},
    )
    log.info(connect.get_collection_info(binary_collection))
    log.info(connect.get_index_info(binary_collection))
    query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(
        connect, binary_collection, nb=1, insert=False)
    # Reference distances from the query vector to both stored vectors.
    reference = [
        tanimoto(query_int_vectors[0], int_vectors[0]),
        tanimoto(query_int_vectors[0], int_vectors[1]),
    ]
    # NOTE: legacy (status, result) search signature; test is disabled via the
    # leading underscore in its name.
    status, result = connect.search(
        binary_collection, top_k, query_vecs, params=get_search_param(index_type))
    log.info(status)
    log.info(result)
    nearest = min(reference[0], reference[1])
    assert abs(result[0][0].distance - nearest) <= epsilon
||
741 | |||
@pytest.mark.timeout(30)
def test_search_concurrent_multithreads(self, connect, args):
    '''
    target: test concurrent search with multiple threads
    method: search from 4 threads, each thread using its own connection
    expected: every search returns one result whose top hit is an inserted id
              with a distance below epsilon
    '''
    threads_num = 4
    threads = []
    # Exceptions raised inside worker threads do not propagate to the main
    # thread, so they are collected here and re-raised after join().
    errors = []
    collection = gen_unique_str(collection_id)
    # create collection
    milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
    milvus.create_collection(collection, default_fields)
    entities, ids = init_data(milvus, collection)

    def search(milvus):
        try:
            # Use the connection handed to this thread, not the shared fixture.
            res = milvus.search(collection, default_query)
            assert len(res) == 1
            assert res[0]._entities[0].id in ids
            assert res[0]._distances[0] < epsilon
        except Exception as exc:
            errors.append(exc)

    for _ in range(threads_num):
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        t = threading.Thread(target=search, args=(milvus, ))
        threads.append(t)
        t.start()
        time.sleep(0.2)
    for t in threads:
        t.join()
    if errors:
        raise errors[0]
||
772 | |||
@pytest.mark.timeout(30)
def test_search_concurrent_multithreads_single_connection(self, connect, args):
    '''
    target: test concurrent search with multiple threads sharing one connection
    method: search from 4 threads, all threads using the same connection
    expected: every search returns one result whose top hit is an inserted id
              with a distance below epsilon
    '''
    threads_num = 4
    threads = []
    # Exceptions raised inside worker threads do not propagate to the main
    # thread, so they are collected here and re-raised after join().
    errors = []
    collection = gen_unique_str(collection_id)
    # create collection
    milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
    milvus.create_collection(collection, default_fields)
    entities, ids = init_data(milvus, collection)

    def search(milvus):
        try:
            # Search through the shared client passed in, not the `connect` fixture.
            res = milvus.search(collection, default_query)
            assert len(res) == 1
            assert res[0]._entities[0].id in ids
            assert res[0]._distances[0] < epsilon
        except Exception as exc:
            errors.append(exc)

    for _ in range(threads_num):
        t = threading.Thread(target=search, args=(milvus, ))
        threads.append(t)
        t.start()
        time.sleep(0.2)
    for t in threads:
        t.join()
    if errors:
        raise errors[0]
||
802 | |||
def test_search_multi_collections(self, connect, args):
    '''
    target: test search over several collections
    method: create 10 collections, insert entities into each one and search it
            with nq query vectors taken from the inserted entities
    expected: each search returns nq results; for every query the matching id is
              found by check_id_result, the first distance is below epsilon and
              the second one is above it
    '''
    num = 10
    top_k = 10
    nq = 20
    for coll_idx in range(num):
        collection = gen_unique_str(collection_id + str(coll_idx))
        connect.create_collection(collection, default_fields)
        entities, ids = init_data(connect, collection)
        assert len(ids) == nb
        query, vecs = gen_query_vectors(
            field_name, entities, top_k, nq, search_params=search_param)
        res = connect.search(collection, query)
        assert len(res) == nq
        for query_idx in range(nq):
            assert check_id_result(res[query_idx], ids[query_idx])
            assert res[query_idx]._distances[0] < epsilon
            assert res[query_idx]._distances[1] > epsilon
||
824 | |||
825 | |||
class TestSearchDSL(object):

    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    # TODO: assert exception
    def test_query_no_must(self, connect, collection):
        '''
        method: build query without must expr
        expected: error raised
        '''
        query = update_query_expr(default_query, keep_old=False)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # TODO:
    def test_query_no_vector_term_only(self, connect, collection):
        '''
        method: build query whose must expr holds only a term expr (no vector)
        expected: error raised
        '''
        expr = {
            # Call the generator: the original put the function object itself
            # into the query instead of a term expression.
            "must": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_wrong_format(self, connect, collection):
        '''
        method: build query without must expr, with wrong expr name
        expected: error raised
        '''
        expr = {
            "must1": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_empty(self, connect, collection):
        '''
        method: search with empty query
        expected: error raised
        '''
        query = {}
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_with_wrong_format_term(self, connect, collection):
        '''
        method: build query with wrong term expr
        expected: error raised
        '''
        # The generator must be called; item assignment on the function object
        # raised TypeError before the search was ever attempted.
        expr = gen_default_term_expr()
        expr["term"] = 1
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    ******************************************************************
    #  The following cases are used to build valid query expr
    ******************************************************************
    """
    def test_query_term_value_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr(values=[100000])
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO: assert on res

    def test_query_term_value_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr(values=[1])
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO: assert on res

    def test_query_term_values_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr(values=list(range(100000, 100010)))
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO: assert on res

    def test_query_term_values_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr()
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO: assert on res

    def test_query_term_values_parts_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with parts of term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        # Integer division: range() rejects the float produced by `nb / 2`.
        half = nb // 2
        expr = gen_default_term_expr(values=list(range(half, nb + half)))
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO: assert on res

    def test_query_term_values_repeat(self, connect, collection):
        '''
        method: build query with vector and term expr, with the same values
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr(values=[1] * (nb - 1))
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO: assert on res
||
963 | |||
964 | |||
class TestSearchDSLBools(object):

    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """
    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        # The original never defined `query`, so the test only passed because of
        # the resulting NameError; send the malformed expression itself.
        query = {"bool1": {}}
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        # Call the generator so a term expression (not the function) is sent.
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        expr = {"must_not": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_must_should(self, connect, collection):
        '''
        method: build query must, and with should.term
        expected: error raised
        '''
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)
||
1030 | |||
1031 | |||
1032 | """ |
||
1033 | ****************************************************************** |
||
1034 | # The following cases are used to test `search` function |
||
1035 | # with invalid collection_name, or invalid query expr |
||
1036 | ****************************************************************** |
||
1037 | """ |
||
1038 | |||
class TestSearchInvalid(object):

    """
    Test search collection with invalid collection names
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # Skip index types the server cannot build when it runs in CPU mode.
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        '''
        method: search with an invalid collection name
        expected: error raised
        '''
        collection_name = get_collection_name
        with pytest.raises(Exception):
            connect.search(collection_name, default_query)

    @pytest.mark.level(1)
    def test_search_with_invalid_tag(self, connect, collection):
        '''
        method: search with a blank partition tag
        expected: error raised
        '''
        tag = " "
        with pytest.raises(Exception):
            connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        '''
        method: search requesting an invalid field name
        expected: error raised
        '''
        fields = [get_invalid_field]
        with pytest.raises(Exception):
            connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        '''
        method: search requesting a field name that was never created
        expected: error raised
        '''
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception):
            connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        # Work on a private copy: writing the invalid topk into the module-level
        # default_query would leak it into every test that runs afterwards.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    Test search collection with invalid search params
    """
    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    # TODO: This case can all pass, but it's too slow
    @pytest.mark.level(2)
    def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        # Decide whether to skip before the expensive insert / index build.
        if search_params["index_type"] != index_type:
            pytest.skip("Skip case")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params={})
        with pytest.raises(Exception):
            connect.search(collection, query)
||
1165 | |||
1166 | |||
def check_id_result(result, id):
    '''
    Tell whether `id` appears among the leading hits of `result`.

    Only the first five hits are inspected when `result` holds five or more
    entries; shorter results are inspected in full.
    '''
    limit_in = 5
    hit_ids = [hit.id for hit in result]
    candidates = hit_ids[:limit_in] if len(result) >= limit_in else hit_ids
    return id in candidates
||
1174 |