Total Complexity: 136
Total Lines: 1363
Duplicated Lines: 28.03 %
Changes: 0
Duplicate code is one of the most pungent code smells. A commonly used rule of thumb is to restructure code once it has been duplicated in three or more places. Each common duplication problem has a corresponding refactoring that removes it, typically by extracting the shared code into a single helper that every call site reuses.
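For instance, the `get_jaccard_index` and `get_hamming_index` fixtures in the listing below are duplicated almost verbatim. The following is only a minimal sketch of one way to consolidate them: the `_binary_flat_index` helper is our own name, and the sketch assumes the module's existing `gen_simple_index` helper (from utils) and `connect` fixture.

import logging

import pytest
from milvus import IndexType
from utils import *  # assumed to provide gen_simple_index(), as in the module under review


def _binary_flat_index(request):
    # Shared body of the formerly duplicated fixtures: only FLAT and IVFLAT
    # index types are accepted for the binary-metric tests, everything else is skipped.
    logging.getLogger().info(request.param)
    if request.param["index_type"] in (IndexType.FLAT, IndexType.IVFLAT):
        return request.param
    pytest.skip("Skip index Temporary")


@pytest.fixture(scope="function", params=gen_simple_index())
def get_jaccard_index(request, connect):
    return _binary_flat_index(request)


@pytest.fixture(scope="function", params=gen_simple_index())
def get_hamming_index(request, connect):
    return _binary_flat_index(request)

With this layout, adding a third binary metric only requires a new two-line fixture rather than another copy of the skip logic.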
Complex classes like test_search often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach for finding such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined which fields and methods belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the quicker change.
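As a concrete illustration of Extract Class on this file, the `init_data`/`init_binary_data` helpers of `TestSearchBase` form such a cohesive component and could live in their own collaborator class. The sketch below is illustrative only: the `SearchDataSeeder` name is ours, and it assumes the module's existing `gen_vectors` helper and a Milvus `connect` handle with `insert`/`flush` methods, as used throughout the listing.

import sklearn.preprocessing
from utils import *  # assumed to provide gen_vectors(), as in the module under review


class SearchDataSeeder:
    """Cohesive data-seeding component extracted from TestSearchBase."""

    def __init__(self, connect, collection):
        self.connect = connect
        self.collection = collection

    def seed_float(self, nb=6000, dim=128, partition_tags=None):
        # Generate, normalize, insert and flush float vectors; return them with their ids.
        add_vectors = sklearn.preprocessing.normalize(
            gen_vectors(nb, dim), axis=1, norm='l2').tolist()
        if partition_tags is None:
            status, ids = self.connect.insert(self.collection, add_vectors)
        else:
            status, ids = self.connect.insert(
                self.collection, add_vectors, partition_tag=partition_tags)
        assert status.OK()
        self.connect.flush([self.collection])
        return add_vectors, ids

Each test would then build a seeder once, e.g. `seeder = SearchDataSeeder(connect, collection)`, and call `seeder.seed_float(...)` instead of re-implementing the insert-and-flush sequence.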
import pdb
import struct
from random import sample
import threading
import datetime
import logging
from time import sleep
import concurrent.futures
from multiprocessing import Process
import pytest
import numpy
import sklearn.preprocessing
from milvus import IndexType, MetricType
from utils import *

dim = 128
collection_id = "test_search"
add_interval_time = 2
vectors = gen_vectors(6000, dim)
vectors = sklearn.preprocessing.normalize(vectors, axis=1, norm='l2')
vectors = vectors.tolist()
top_k = 1
nprobe = 1
epsilon = 0.001
tag = "1970-01-01"
raw_vectors, binary_vectors = gen_binary_vectors(6000, dim)

class TestSearchBase:
    def init_data(self, connect, collection, nb=6000, dim=dim, partition_tags=None):
        '''
        Generate vectors and add them to the collection before searching
        '''
        global vectors
        if nb == 6000:
            add_vectors = vectors
        else:
            add_vectors = gen_vectors(nb, dim)
            add_vectors = sklearn.preprocessing.normalize(add_vectors, axis=1, norm='l2')
            add_vectors = add_vectors.tolist()
        if partition_tags is None:
            status, ids = connect.insert(collection, add_vectors)
            assert status.OK()
        else:
            status, ids = connect.insert(collection, add_vectors, partition_tag=partition_tags)
            assert status.OK()
        connect.flush([collection])
        return add_vectors, ids

    def init_binary_data(self, connect, collection, nb=6000, dim=dim, insert=True, partition_tags=None):
        '''
        Generate binary vectors and add them to the collection before searching
        '''
        ids = []
        global binary_vectors
        global raw_vectors
        if nb == 6000:
            add_vectors = binary_vectors
            add_raw_vectors = raw_vectors
        else:
            add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim)
        if insert is True:
            if partition_tags is None:
                status, ids = connect.insert(collection, add_vectors)
                assert status.OK()
            else:
                status, ids = connect.insert(collection, add_vectors, partition_tag=partition_tags)
                assert status.OK()
            connect.flush([collection])
        return add_raw_vectors, add_vectors, ids

    """
    generate valid create_index params
    """

    @pytest.fixture(
        scope="function",
        params=gen_index()
    )
    def get_index(self, request, connect):
        if str(connect._cmd("mode")[1]) == "CPU":
            if request.param["index_type"] == IndexType.IVF_SQ8H:
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        if str(connect._cmd("mode")[1]) == "CPU":
            if request.param["index_type"] == IndexType.IVF_SQ8H:
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] == IndexType.IVFLAT or request.param["index_type"] == IndexType.FLAT:
            return request.param
        else:
            pytest.skip("Skip index Temporary")

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_hamming_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] == IndexType.IVFLAT or request.param["index_type"] == IndexType.FLAT:
            return request.param
        else:
            pytest.skip("Skip index Temporary")

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_structure_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] == IndexType.FLAT:
            return request.param
        else:
            pytest.skip("Skip index Temporary")

    """
    generate top-k params
    """

    @pytest.fixture(
        scope="function",
        params=[1, 99, 1024, 2049, 16385]
    )
    def get_top_k(self, request):
        yield request.param

    def test_search_top_k_flat_index(self, connect, collection, get_top_k):
        '''
        target: test basic search function, all the search params are correct, change top-k value
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        vectors, ids = self.init_data(connect, collection)
        query_vec = [vectors[0]]
        top_k = get_top_k
        status, result = connect.search(collection, top_k, query_vec)
        if top_k <= 16384:
            assert status.OK()
            assert len(result[0]) == min(len(vectors), top_k)
            assert result[0][0].distance <= epsilon
            assert check_result(result[0], ids[0])
        else:
            assert not status.OK()

    def test_search_top_k_flat_index_metric_type(self, connect, collection):
        '''
        target: test basic search function, all the search params are correct, change top-k value
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        vectors, ids = self.init_data(connect, collection)
        query_vec = [vectors[0]]
        status, result = connect.search(collection, top_k, query_vec, params={"metric_type": MetricType.IP.value})
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert result[0][0].distance >= 1 - epsilon
        assert check_result(result[0], ids[0])

    @pytest.mark.level(2)
    def test_search_top_k_flat_index_metric_type_invalid(self, connect, collection):
        '''
        target: test basic search function, all the search params are correct, change top-k value
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        vectors, ids = self.init_data(connect, collection)
        query_vec = [vectors[0]]
        status, result = connect.search(collection, top_k, query_vec, params={"metric_type": MetricType.JACCARD.value})
        assert not status.OK()

    def test_search_l2_index_params(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        vectors, ids = self.init_data(connect, collection)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0], vectors[1]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        if top_k <= 1024:
            assert status.OK()
            assert len(result[0]) == min(len(vectors), top_k)
            assert check_result(result[0], ids[0])
            assert result[0][0].distance < result[0][1].distance
            assert result[1][0].distance < result[1][1].distance
        else:
            assert not status.OK()

    def test_search_l2_large_nq_index_params(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")

        vectors, ids = self.init_data(connect, collection)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = vectors[:1000]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon

    def test_search_with_multi_partitions(self, connect, collection):
        '''
        target: test search with multiple partitions, which contain the default tag and other tags
        method: insert vectors into one partition and search with partitions [_default, tag]
        expected: search result is correct
        '''
        connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, nb=10, partition_tags=tag)
        query_vec = [vectors[0]]
        search_param = get_search_param(IndexType.FLAT)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=["_default", tag],
                                        params=search_param)
        assert status.OK()
        logging.getLogger().info(result)
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon

    def test_search_l2_index_params_partition(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: add vectors into collection, search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k; searching the collection with the partition tag returns empty
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon
        status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

    def test_search_l2_index_params_partition_A(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search partition with the given vectors, check the result
        expected: search status ok, and the length of the result is 0
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")

        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

    def test_search_l2_index_params_partition_B(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon
        status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon

    def test_search_l2_index_params_partition_C(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors and tags (one of the tags does not exist in the collection), check the result
        expected: search status ok, and the length of the result is top_k
        '''
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"],
                                        params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon

    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_D(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors and tag (tag name does not exist in the collection), check the result
        expected: search status ok, and the length of the result is top_k
        '''
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        status = connect.create_partition(collection, tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0]]
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=["new_tag"], params=search_param)
        logging.getLogger().info(result)
        assert not status.OK()

    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_E(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search collection with the given vectors and tags, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        top_k = 10
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        logging.getLogger().info(get_simple_index)
        status = connect.create_partition(collection, tag)
        status = connect.create_partition(collection, new_tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        new_vectors, new_ids = self.init_data(connect, collection, nb=6001, partition_tags=new_tag)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0], new_vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag],
                                        params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert check_result(result[1], new_ids[0])
        assert result[0][0].distance <= epsilon
        assert result[1][0].distance <= epsilon
        status, result = connect.search(collection, top_k, query_vec, partition_tags=[new_tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[1], new_ids[0])
        assert result[1][0].distance <= epsilon

    def test_search_l2_index_params_partition_F(self, connect, collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search collection with the given vectors and tags with "re" expr, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        tag = "atag"
        new_tag = "new_tag"
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        status = connect.create_partition(collection, tag)
        status = connect.create_partition(collection, new_tag)
        vectors, ids = self.init_data(connect, collection, partition_tags=tag)
        new_vectors, new_ids = self.init_data(connect, collection, nb=6001, partition_tags=new_tag)
        status = connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0], new_vectors[0]]
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, partition_tags=["new(.*)"], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert result[0][0].distance > epsilon
        assert result[1][0].distance <= epsilon
        status, result = connect.search(collection, top_k, query_vec, partition_tags=["(.*)tag"], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert result[0][0].distance <= epsilon
        assert result[1][0].distance <= epsilon

    @pytest.mark.level(2)
    def test_search_ip_index_params(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        vectors, ids = self.init_data(connect, ip_collection)
        status = connect.create_index(ip_collection, index_type, index_param)
        query_vec = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance >= result[0][1].distance

    def test_search_ip_large_nq_index_params(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")
        vectors, ids = self.init_data(connect, ip_collection)
        status = connect.create_index(ip_collection, index_type, index_param)
        query_vec = []
        for i in range(1200):
            query_vec.append(vectors[i])
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance)

    @pytest.mark.level(2)
    def test_search_ip_index_params_partition(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")

        status = connect.create_partition(ip_collection, tag)
        vectors, ids = self.init_data(connect, ip_collection)
        status = connect.create_index(ip_collection, index_type, index_param)
        query_vec = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance)
        status, result = connect.search(ip_collection, top_k, query_vec, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result) == 0

    @pytest.mark.level(2)
    def test_search_ip_index_params_partition_A(self, connect, ip_collection, get_simple_index):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build
        method: search with the given vectors and tag, check the result
        expected: search status ok, and the length of the result is top_k
        '''
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")

        status = connect.create_partition(ip_collection, tag)
        vectors, ids = self.init_data(connect, ip_collection, partition_tags=tag)
        status = connect.create_index(ip_collection, index_type, index_param)
        query_vec = [vectors[0]]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vec, partition_tags=[tag], params=search_param)
        logging.getLogger().info(result)
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance >= 1 - gen_inaccuracy(result[0][0].distance)

    @pytest.mark.level(2)
    def test_search_vectors_without_connect(self, dis_connect, collection):
        '''
        target: test search vectors without connection
        method: use a disconnected instance, call the search method and check whether the search succeeds
        expected: raise exception
        '''
        query_vectors = [vectors[0]]
        nprobe = 1
        with pytest.raises(Exception) as e:
            status, ids = dis_connect.search(collection, top_k, query_vectors)

    def test_search_collection_name_not_existed(self, connect, collection):
        '''
        target: search a collection that does not exist
        method: search with a random collection_name that is not in the db
        expected: status not ok
        '''
        collection_name = gen_unique_str("not_existed_collection")
        nprobe = 1
        query_vecs = [vectors[0]]
        status, result = connect.search(collection_name, top_k, query_vecs)
        assert not status.OK()

    def test_search_collection_name_None(self, connect, collection):
        '''
        target: search a collection whose name is None
        method: search with collection_name: None
        expected: status not ok
        '''
        collection_name = None
        nprobe = 1
        query_vecs = [vectors[0]]
        with pytest.raises(Exception) as e:
            status, result = connect.search(collection_name, top_k, query_vecs)

    def test_search_top_k_query_records(self, connect, collection):
        '''
        target: test search function, with search params: query_records
        method: search with the given query_records, which are subarrays of the inserted vectors
        expected: status ok and the returned vectors should be query_records
        '''
        top_k = 10
        vectors, ids = self.init_data(connect, collection)
        query_vecs = [vectors[0], vectors[55], vectors[99]]
        status, result = connect.search(collection, top_k, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for i in range(len(query_vecs)):
            assert len(result[i]) == top_k
            assert result[i][0].distance <= epsilon

    def test_search_distance_l2_flat_index(self, connect, collection):
        '''
        target: search collection, and check the result: distance
        method: compare the returned distance value with the value computed with the Euclidean distance
        expected: the returned distance equals the computed value
        '''
        nb = 2
        vectors, ids = self.init_data(connect, collection, nb=nb)
        query_vecs = [[0.50 for i in range(dim)]]
        distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
        distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
        status, result = connect.search(collection, top_k, query_vecs)
        assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(
            result[0][0].distance)

    def test_search_distance_ip_flat_index(self, connect, ip_collection):
        '''
        target: search ip_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the inner product
        expected: the returned distance equals the computed value
        '''
        nb = 2
        nprobe = 1
        vectors, ids = self.init_data(connect, ip_collection, nb=nb)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(ip_collection, index_type, index_param)
        logging.getLogger().info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50 for i in range(dim)]]
        distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
        distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)

    def test_search_distance_jaccard_flat_index(self, connect, jac_collection):
        '''
        target: search jac_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the Jaccard distance
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, jac_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(jac_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(jac_collection))
        logging.getLogger().info(connect.get_index_info(jac_collection))
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon

    def test_search_distance_jaccard_flat_index_metric_type(self, connect, jac_collection):
        '''
        target: search jac_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the HAMMING metric
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, jac_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(jac_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(jac_collection))
        logging.getLogger().info(connect.get_index_info(jac_collection))
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        search_param.update({"metric_type": MetricType.HAMMING.value})
        status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
        assert status.OK()
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon

    def test_search_distance_hamming_flat_index(self, connect, ham_collection):
        '''
        target: search ham_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the Hamming distance
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, ham_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(ham_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(ham_collection))
        logging.getLogger().info(connect.get_index_info(ham_collection))
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, ham_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(ham_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon

    def test_search_distance_substructure_flat_index(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the substructure metric
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1,
                                                                       insert=False)
        distance_0 = substructure(query_int_vectors[0], int_vectors[0])
        distance_1 = substructure(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0

    def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with SUB
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        top_k = 3
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 1
        assert len(result[1]) == 1
        assert result[0][0].distance <= epsilon
        assert result[0][0].id == ids[0]
        assert result[1][0].distance <= epsilon
        assert result[1][0].id == ids[1]

    def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the superstructure metric
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1,
                                                                       insert=False)
        distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
        distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0

    def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with SUPER
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        top_k = 3
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert result[0][0].id in ids
        assert result[0][0].distance <= epsilon
        assert result[1][0].id in ids
        assert result[1][0].distance <= epsilon

    def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
        '''
        target: search tanimoto_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the Tanimoto distance
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = self.init_binary_data(connect, tanimoto_collection, nb=2)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(tanimoto_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(tanimoto_collection))
        logging.getLogger().info(connect.get_index_info(tanimoto_collection))
        query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, tanimoto_collection, nb=1, insert=False)
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(tanimoto_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon

    def test_search_distance_ip_index_params(self, connect, ip_collection, get_index):
        '''
        target: search collection, and check the result: distance
        method: compare the returned distance value with the value computed with the inner product
        expected: the returned distance equals the computed value
        '''
        top_k = 2
        nprobe = 1
        index_param = get_index["index_param"]
        index_type = get_index["index_type"]
        if index_type == IndexType.RNSG:
            pytest.skip("rnsg not support in ip")
        vectors, ids = self.init_data(connect, ip_collection, nb=2)
        connect.create_index(ip_collection, index_type, index_param)
        logging.getLogger().info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50 for i in range(dim)]]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().debug(status)
        logging.getLogger().debug(result)
        distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
        distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
        assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)

    # def test_search_concurrent(self, connect, collection):
    #     vectors, ids = self.init_data(connect, collection, nb=5000)
    #     thread_num = 50
    #     nq = 1
    #     top_k = 2
    #     threads = []
    #     query_vecs = vectors[:nq]
    #     def search(thread_number):
    #         for i in range(1000000):
    #             status, result = connect.search(collection, top_k, query_vecs, timeout=2)
    #             assert len(result) == len(query_vecs)
    #             assert status.OK()
    #             if i % 1000 == 0:
    #                 logging.getLogger().info("In %d, %d" % (thread_number, i))
    #         logging.getLogger().info("%d finished" % thread_number)
    #     # with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
    #     #     future_results = {executor.submit(
    #     #         search): i for i in range(1000000)}
    #     #     for future in concurrent.futures.as_completed(future_results):
    #     #         future.result()
    #     for i in range(thread_num):
    #         t = threading.Thread(target=search, args=(i, ))
    #         threads.append(t)
    #         t.start()
    #     for t in threads:
    #         t.join()

    @pytest.mark.level(2)
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, args):
        '''
        target: test concurrent search with multiple threads
        method: search with several threads, each thread using its own connection
        expected: status ok and the returned vectors should be query_records
        '''
        nb = 100
        top_k = 10
        threads_num = 4
        threads = []
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        vectors, ids = self.init_data(milvus, collection, nb=nb)
        query_vecs = vectors[nb // 2:nb]

        def search(milvus):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert len(result) == len(query_vecs)
            for i in range(len(query_vecs)):
                assert result[i][0].id in ids
                assert result[i][0].distance == 0.0

        for i in range(threads_num):
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            t = threading.Thread(target=search, args=(milvus,))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()

    # TODO: enable
    @pytest.mark.timeout(30)
    def _test_search_concurrent_multiprocessing(self, args):
        '''
        target: test concurrent search with multiprocessing
        method: search with several processes, each process using its own connection
        expected: status ok and the returned vectors should be query_records
        '''
        nb = 100
        top_k = 10
        process_num = 4
        processes = []
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        vectors, ids = self.init_data(milvus, collection, nb=nb)
        query_vecs = vectors[nb // 2:nb]

        def search(milvus):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert len(result) == len(query_vecs)
            for i in range(len(query_vecs)):
                assert result[i][0].id in ids
                assert result[i][0].distance == 0.0

        for i in range(process_num):
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            p = Process(target=search, args=(milvus,))
            processes.append(p)
            p.start()
            time.sleep(0.2)
        for p in processes:
            p.join()

    def test_search_multi_collection_L2(search, args):
        '''
        target: test search multi collections of L2
        method: add vectors into 10 collections, and search
        expected: search status ok, the length of result
        '''
        num = 10
        top_k = 10
        collections = []
        idx = []
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            uri = "tcp://%s:%s" % (args["ip"], args["port"])
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     'metric_type': MetricType.L2}
            # create collection
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            idx.append(ids[0])
            idx.append(ids[10])
            idx.append(ids[20])
            milvus.flush([collection])
        query_vecs = [vectors[0], vectors[10], vectors[20]]
        # start query from random collection
        for i in range(num):
            collection = collections[i]
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
            for j in range(len(query_vecs)):
                assert check_result(result[j], idx[3 * i + j])

    def test_search_multi_collection_IP(search, args):
        '''
        target: test search multi collections of IP
        method: add vectors into 10 collections, and search
        expected: search status ok, the length of result
        '''
        num = 10
        top_k = 10
        collections = []
        idx = []
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            uri = "tcp://%s:%s" % (args["ip"], args["port"])
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     'metric_type': MetricType.L2}
            # create collection
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            idx.append(ids[0])
            idx.append(ids[10])
            idx.append(ids[20])
            milvus.flush([collection])
        query_vecs = [vectors[0], vectors[10], vectors[20]]
        # start query from random collection
        for i in range(num):
            collection = collections[i]
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
            for j in range(len(query_vecs)):
                assert check_result(result[j], idx[3 * i + j])

    @pytest.fixture(params=MetricType)
    def get_binary_metric_types(self, request):
        if request.param == MetricType.INVALID:
            pytest.skip(("metric type invalid"))
        if request.param in [MetricType.L2, MetricType.IP]:
            pytest.skip(("L2 and IP not support in binary"))
        return request.param

    # 4678 and # 4683
    def test_search_binary_dim_not_power_of_2(self, connect, get_binary_metric_types):
        metric = get_binary_metric_types
        collection = gen_unique_str(collection_id)
        dim = 200
        top_k = 1
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_file_size': 10,
                 'metric_type': metric}
        status = connect.create_collection(param)
        assert status.OK()
        int_vectors, vectors, ids = self.init_binary_data(connect, collection, nb=1000, dim=dim)
        search_param = get_search_param(IndexType.FLAT)
        status, result = connect.search(collection, top_k, vectors[:1], params=search_param)
        assert status.OK()
        assert result[0][0].id in ids
        assert result[0][0].distance == 0.0

    @pytest.fixture(params=MetricType)
    def get_metric_types(self, request):
        if request.param == MetricType.INVALID:
            pytest.skip(("metric type invalid"))
        if request.param not in [MetricType.L2, MetricType.IP]:
            pytest.skip(("L2 and IP not support in binary"))
        return request.param

    def test_search_float_dim_not_power_of_2(self, connect, get_metric_types):
        metric = get_metric_types
        collection = gen_unique_str(collection_id)
        dim = 200
        top_k = 1
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_file_size': 10,
                 'metric_type': metric}
        status = connect.create_collection(param)
        assert status.OK()
        vectors, ids = self.init_data(connect, collection, nb=1000, dim=dim)
        search_param = get_search_param(IndexType.FLAT)
        status, result = connect.search(collection, top_k, vectors[:1], params=search_param)
        assert status.OK()
        assert result[0][0].id in ids

"""
******************************************************************
# The following cases are used to test `search_vectors` function
# with invalid collection_name top-k / nprobe / query_range
******************************************************************
"""

1121 | class TestSearchParamsInvalid(object): |
||
1122 | nlist = 16384 |
||
1123 | index_type = IndexType.IVF_SQ8 |
||
1124 | index_param = {"nlist": nlist} |
||
1125 | logging.getLogger().info(index_param) |
||
1126 | |||
1127 | def init_data(self, connect, collection, nb=6000): |
||
1128 | ''' |
||
1129 | Generate vectors and add them to the collection before searching |
||
1130 | ''' |
||
1131 | global vectors |
||
1132 | if nb == 6000: |
||
1133 | insert = vectors |
||
1134 | else: |
||
1135 | insert = gen_vectors(nb, dim) |
||
1136 | status, ids = connect.insert(collection, insert) |
||
1137 | connect.flush([collection]) |
||
1138 | return insert, ids |
||
1139 | |||
1140 | """ |
||
1141 | Test search collection with invalid collection names |
||
1142 | """ |
||
1143 | |||
1144 | @pytest.fixture( |
||
1145 | scope="function", |
||
1146 | params=gen_invalid_collection_names() |
||
1147 | ) |
||
1148 | def get_collection_name(self, request): |
||
1149 | yield request.param |
||
1150 | |||
1151 | @pytest.mark.level(2) |
||
1152 | def test_search_with_invalid_collectionname(self, connect, get_collection_name): |
||
1153 | collection_name = get_collection_name |
||
1154 | logging.getLogger().info(collection_name) |
||
1155 | nprobe = 1 |
||
1156 | query_vecs = gen_vectors(1, dim) |
||
1157 | status, result = connect.search(collection_name, top_k, query_vecs) |
||
1158 | assert not status.OK() |
||
1159 | |||
1160 | @pytest.mark.level(1) |
||
1161 | def test_search_with_invalid_tag_format(self, connect, collection): |
||
1162 | nprobe = 1 |
||
1163 | query_vecs = gen_vectors(1, dim) |
||
1164 | with pytest.raises(Exception) as e: |
||
1165 | status, result = connect.search(collection, top_k, query_vecs, partition_tags="tag") |
||
1166 | logging.getLogger().debug(result) |
||
1167 | |||
1168 | @pytest.mark.level(1) |
||
1169 | def test_search_with_tag_not_existed(self, connect, collection): |
||
1170 | nprobe = 1 |
||
1171 | query_vecs = gen_vectors(1, dim) |
||
1172 | status, result = connect.search(collection, top_k, query_vecs, partition_tags=["tag"]) |
||
1173 | logging.getLogger().info(result) |
||
1174 | assert not status.OK() |
||
1175 | |||
1176 | """ |
||
1177 | Test search collection with invalid top-k |
||
1178 | """ |
||
1179 | |||
1180 | @pytest.fixture( |
||
1181 | scope="function", |
||
1182 | params=gen_invalid_top_ks() |
||
1183 | ) |
||
1184 | def get_top_k(self, request): |
||
1185 | yield request.param |
||
1186 | |||
1187 | View Code Duplication | @pytest.mark.level(1) |
|
1188 | def test_search_with_invalid_top_k(self, connect, collection, get_top_k): |
||
1189 | ''' |
||
1190 | target: test search function, with the wrong top_k |
||
1191 | method: search with top_k |
||
1192 | expected: raise an error, and the connection is normal |
||
1193 | ''' |
||
1194 | top_k = get_top_k |
||
1195 | logging.getLogger().info(top_k) |
||
1196 | nprobe = 1 |
||
1197 | query_vecs = gen_vectors(1, dim) |
||
1198 | if isinstance(top_k, int): |
||
1199 | status, result = connect.search(collection, top_k, query_vecs) |
||
1200 | assert not status.OK() |
||
1201 | else: |
||
1202 | with pytest.raises(Exception) as e: |
||
1203 | status, result = connect.search(collection, top_k, query_vecs) |
||
1204 | |||
1205 | View Code Duplication | @pytest.mark.level(2) |
|
1206 | def test_search_with_invalid_top_k_ip(self, connect, ip_collection, get_top_k): |
||
1207 | ''' |
||
1208 | target: test search function, with the wrong top_k |
||
1209 | method: search with top_k |
||
1210 | expected: raise an error, and the connection is normal |
||
1211 | ''' |
||
1212 | top_k = get_top_k |
||
1213 | logging.getLogger().info(top_k) |
||
1214 | nprobe = 1 |
||
1215 | query_vecs = gen_vectors(1, dim) |
||
1216 | if isinstance(top_k, int): |
||
1217 | status, result = connect.search(ip_collection, top_k, query_vecs) |
||
1218 | assert not status.OK() |
||
1219 | else: |
||
1220 | with pytest.raises(Exception) as e: |
||
1221 | status, result = connect.search(ip_collection, top_k, query_vecs) |
||
1222 | |||
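The two invalid `top_k` tests above are flagged for duplication and differ only in which collection fixture they search. A minimal sketch of how they could be folded into one parametrized test follows; the `any_collection` fixture name and the `getfixturevalue` dispatch are illustrative additions, not part of this suite, and the sketch assumes the same conftest fixtures (`connect`, `collection`, `ip_collection`) and utils helpers (`gen_vectors`, `dim`) used throughout the file.

```python
import pytest

# Hypothetical fixture: resolves either the L2 or the IP collection fixture,
# so one test body can exercise both (sketch only, not the suite's code).
@pytest.fixture(params=["collection", "ip_collection"])
def any_collection(request):
    return request.getfixturevalue(request.param)


@pytest.mark.level(2)
def test_search_with_invalid_top_k_any(connect, any_collection, get_top_k):
    top_k = get_top_k
    query_vecs = gen_vectors(1, dim)
    if isinstance(top_k, int):
        # a wrong but well-typed top_k should come back as a failing status
        status, result = connect.search(any_collection, top_k, query_vecs)
        assert not status.OK()
    else:
        # a top_k of the wrong type should raise on the client side
        with pytest.raises(Exception):
            connect.search(any_collection, top_k, query_vecs)
```

The same pattern would also collapse the two invalid `nprobe` tests in the next group.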
1223 | """ |
||
1224 | Test search collection with invalid nprobe |
||
1225 | """ |
||
1226 | |||
1227 | @pytest.fixture( |
||
1228 | scope="function", |
||
1229 | params=gen_invalid_nprobes() |
||
1230 | ) |
||
1231 | def get_nprobes(self, request): |
||
1232 | yield request.param |
||
1233 | |||
1234 | View Code Duplication | @pytest.mark.level(1) |
|
1235 | def test_search_with_invalid_nprobe(self, connect, collection, get_nprobes): |
||
1236 | ''' |
||
1237 | target: test search function, with the wrong nprobe |
||
1238 | method: search with nprobe |
||
1239 | expected: raise an error, and the connection is normal |
||
1240 | ''' |
||
1241 | index_type = IndexType.IVF_SQ8 |
||
1242 | index_param = {"nlist": 16384} |
||
1243 | connect.create_index(collection, index_type, index_param) |
||
1244 | nprobe = get_nprobes |
||
1245 | search_param = {"nprobe": nprobe} |
||
1246 | logging.getLogger().info(nprobe) |
||
1247 | query_vecs = gen_vectors(1, dim) |
||
1248 | # if isinstance(nprobe, int): |
||
1249 | status, result = connect.search(collection, top_k, query_vecs, params=search_param) |
||
1250 | assert not status.OK() |
||
1251 | # else: |
||
1252 | # with pytest.raises(Exception) as e: |
||
1253 | # status, result = connect.search(collection, top_k, query_vecs, params=search_param) |
||
1254 | |||
1255 | View Code Duplication | @pytest.mark.level(2) |
|
1256 | def test_search_with_invalid_nprobe_ip(self, connect, ip_collection, get_nprobes): |
||
1257 | ''' |
||
1258 | target: test search function, with the wrong nprobe |
||
1259 | method: search with nprobe |
||
1260 | expected: raise an error, and the connection is normal |
||
1261 | ''' |
||
1262 | index_type = IndexType.IVF_SQ8 |
||
1263 | index_param = {"nlist": 16384} |
||
1264 | connect.create_index(ip_collection, index_type, index_param) |
||
1265 | nprobe = get_nprobes |
||
1266 | search_param = {"nprobe": nprobe} |
||
1267 | logging.getLogger().info(nprobe) |
||
1268 | query_vecs = gen_vectors(1, dim) |
||
1269 | |||
1270 | # if isinstance(nprobe, int): |
||
1271 | status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param) |
||
1272 | assert not status.OK() |
||
1273 | # else: |
||
1274 | # with pytest.raises(Exception) as e: |
||
1275 | # status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param) |
||
1276 | |||
1277 | def test_search_with_2049_nprobe(self, connect, collection): |
||
1278 | ''' |
||
1279 | target: test search function, with 2049 nprobe in GPU mode |
||
1280 | method: search with nprobe |
||
1281 | expected: status not ok |
||
1282 | ''' |
||
1283 | if str(connect._cmd("mode")[1]) == "CPU": |
||
1284 | pytest.skip("Only supported in GPU mode") |
||
1285 | for index in gen_simple_index(): |
||
1286 | if index["index_type"] in [IndexType.IVF_PQ, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]: |
||
1287 | index_type = index["index_type"] |
||
1288 | index_param = index["index_param"] |
||
1289 | self.init_data(connect, collection) |
||
1290 | connect.create_index(collection, index_type, index_param) |
||
1291 | nprobe = 2049 |
||
1292 | search_param = {"nprobe": nprobe} |
||
1293 | query_vecs = gen_vectors(nprobe, dim) |
||
1294 | status, result = connect.search(collection, top_k, query_vecs, params=search_param) |
||
1295 | assert status.OK() |
||
1296 | |||
1297 | View Code Duplication | @pytest.fixture( |
|
1298 | scope="function", |
||
1299 | params=gen_simple_index() |
||
1300 | ) |
||
1301 | def get_simple_index(self, request, connect): |
||
1302 | if str(connect._cmd("mode")[1]) == "CPU": |
||
1303 | if request.param["index_type"] == IndexType.IVF_SQ8H: |
||
1304 | pytest.skip("sq8h not supported in CPU mode") |
||
1305 | if str(connect._cmd("mode")[1]) == "GPU": |
||
1306 | if request.param["index_type"] == IndexType.IVF_PQ: |
||
1307 | pytest.skip("ivfpq not supported in GPU mode") |
||
1308 | return request.param |
||
1309 | |||
1310 | def test_search_with_empty_params(self, connect, collection, args, get_simple_index): |
||
1311 | ''' |
||
1312 | target: test search function, with empty search params |
||
1313 | method: search with an empty params dict |
||
1314 | expected: search status not ok (FLAT index excepted), and the connection is normal |
||
1315 | ''' |
||
1316 | if args["handler"] == "HTTP": |
||
1317 | pytest.skip("skip in http mode") |
||
1318 | index_type = get_simple_index["index_type"] |
||
1319 | index_param = get_simple_index["index_param"] |
||
1320 | connect.create_index(collection, index_type, index_param) |
||
1321 | query_vecs = gen_vectors(1, dim) |
||
1322 | status, result = connect.search(collection, top_k, query_vecs, params={}) |
||
1323 | |||
1324 | if index_type == IndexType.FLAT: |
||
1325 | assert status.OK() |
||
1326 | else: |
||
1327 | assert not status.OK() |
||
1328 | |||
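As the branch above shows, only a FLAT index tolerates an empty `params` dict; the IVF-family indexes need at least `nprobe`. For contrast, a minimal well-formed call against an IVF_SQ8 index, written as a sketch with the same names used by the surrounding tests:

```python
# Sketch: the smallest valid search against an IVF_SQ8 index.
# Unlike the empty-dict case above, params must carry nprobe.
index_param = {"nlist": 16384}
connect.create_index(collection, IndexType.IVF_SQ8, index_param)
query_vecs = gen_vectors(1, dim)
status, result = connect.search(collection, top_k, query_vecs, params={"nprobe": 16})
assert status.OK()
```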
1329 | @pytest.fixture( |
||
1330 | scope="function", |
||
1331 | params=gen_invaild_search_params() |
||
1332 | ) |
||
1333 | def get_invalid_search_param(self, request, connect): |
||
1334 | if str(connect._cmd("mode")[1]) == "CPU": |
||
1335 | if request.param["index_type"] == IndexType.IVF_SQ8H: |
||
1336 | pytest.skip("sq8h not supported in CPU mode") |
||
1337 | if str(connect._cmd("mode")[1]) == "GPU": |
||
1338 | if request.param["index_type"] == IndexType.IVF_PQ: |
||
1339 | pytest.skip("ivfpq not supported in GPU mode") |
||
1340 | return request.param |
||
1341 | |||
1342 | def test_search_with_invalid_params(self, connect, collection, get_invalid_search_param): |
||
1343 | ''' |
||
1344 | target: test search function, with invalid search params |
||
1345 | method: search with params |
||
1346 | expected: search status not ok, and the connection is normal |
||
1347 | ''' |
||
1348 | index_type = get_invalid_search_param["index_type"] |
||
1349 | search_param = get_invalid_search_param["search_param"] |
||
1350 | for index in gen_simple_index(): |
||
1351 | if index_type == index["index_type"]: |
||
1352 | connect.create_index(collection, index_type, index["index_param"]) |
||
1353 | query_vecs = gen_vectors(1, dim) |
||
1354 | status, result = connect.search(collection, top_k, query_vecs, params=search_param) |
||
1355 | assert not status.OK() |
||
1356 | |||
1357 | |||
1358 | def check_result(result, id): |
||
1359 | if len(result) >= 5: |
||
1360 | return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id] |
||
1361 | else: |
||
1362 | return id in (i.id for i in result) |
||
1363 |
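Both test classes build their data with the same insert-then-flush sequence (each defines its own `init_data`). A module-level helper along the following lines would remove that repetition; this is a sketch under the same `connect`/`gen_vectors` assumptions as the file, not a drop-in replacement, since it skips the 6000-vector caching and normalization details the class methods handle.

```python
def insert_and_flush(connect, collection, nb=6000, dim=dim):
    # Shared helper sketch: generate vectors, insert them, flush, and hand
    # back the vectors together with the ids assigned by the server.
    add_vectors = gen_vectors(nb, dim)
    status, ids = connect.insert(collection, add_vectors)
    assert status.OK()
    connect.flush([collection])
    return add_vectors, ids
```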