TestSearchBase.test_search_ip_large_nq_index_params()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 24
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 19
nop 4
dl 0
loc 24
rs 9.45
c 0
b 0
f 0
1
import pdb
2
import struct
3
from random import sample
4
import threading
5
import datetime
6
import logging
7
from time import sleep
8
import concurrent.futures
9
from multiprocessing import Process
10
import pytest
11
import numpy
12
import sklearn.preprocessing
13
from milvus import IndexType, MetricType
14
from utils import *
15
16
# Dimensionality used for every vector generated in this module.
dim = 128
collection_id = "test_search"
add_interval_time = 2
# Shared corpus of 6000 float vectors: each row is scaled to unit L2 norm and
# the result is converted to plain Python lists.
vectors = sklearn.preprocessing.normalize(gen_vectors(6000, dim), axis=1, norm='l2').tolist()
# Defaults used by the tests that do not set their own local values.
top_k = 1
nprobe = 1
# Tolerance used when comparing returned distances.
epsilon = 0.001
# Partition tag shared by the partition tests.
tag = "1970-01-01"
# Shared corpus of 6000 binary vectors, as the (raw, binary) pair returned by gen_binary_vectors.
raw_vectors, binary_vectors = gen_binary_vectors(6000, dim)
27
28
29
class TestSearchBase:
30
    def init_data(self, connect, collection, nb=6000, dim=dim, partition_tags=None):
        """
        Insert float vectors into ``collection`` and flush it, ready for searching.

        The module-level ``vectors`` corpus is reused when ``nb`` is 6000; for any
        other ``nb`` a new batch is generated, L2-normalised and converted to lists.
        When ``partition_tags`` is not None it is forwarded to ``connect.insert``
        as ``partition_tag``.

        Returns the inserted vectors and the ids returned by ``connect.insert``.
        """
        if nb != 6000:
            fresh = gen_vectors(nb, dim)
            fresh = sklearn.preprocessing.normalize(fresh, axis=1, norm='l2')
            add_vectors = fresh.tolist()
        else:
            add_vectors = vectors
        # Only pass partition_tag when a tag was actually requested.
        insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
        status, ids = connect.insert(collection, add_vectors, **insert_kwargs)
        assert status.OK()
        connect.flush([collection])
        return add_vectors, ids
49
50
    def init_binary_data(self, connect, collection, nb=6000, dim=dim, insert=True, partition_tags=None):
        """
        Prepare binary vectors and, when ``insert`` is True, add them to ``collection``.

        The module-level ``raw_vectors`` / ``binary_vectors`` are reused when ``nb``
        is 6000; otherwise a new pair is produced by ``gen_binary_vectors``.
        When ``partition_tags`` is not None it is forwarded to ``connect.insert``
        as ``partition_tag``.

        Returns (raw vectors, binary vectors, ids); ``ids`` is an empty list when
        nothing was inserted.
        """
        if nb == 6000:
            add_raw_vectors, add_vectors = raw_vectors, binary_vectors
        else:
            add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim)
        if insert is not True:
            # Caller only wants the data, e.g. to use it as query vectors.
            return add_raw_vectors, add_vectors, []
        insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
        status, ids = connect.insert(collection, add_vectors, **insert_kwargs)
        assert status.OK()
        connect.flush([collection])
        return add_raw_vectors, add_vectors, ids
71
72
    """
73
    generate valid create_index params
74
    """
75
76
    @pytest.fixture(scope="function", params=gen_index())
    def get_index(self, request, connect):
        """
        Provide each index configuration from gen_index().

        IVF_SQ8H configurations are skipped when the server reports "CPU" mode.
        """
        index_spec = request.param
        running_on_cpu = str(connect._cmd("mode")[1]) == "CPU"
        if running_on_cpu and index_spec["index_type"] == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        return index_spec
85
86
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_simple_index(self, request, connect):
        """
        Provide each index configuration from gen_simple_index().

        IVF_SQ8H configurations are skipped when the server reports "CPU" mode.
        """
        index_spec = request.param
        running_on_cpu = str(connect._cmd("mode")[1]) == "CPU"
        if running_on_cpu and index_spec["index_type"] == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        return index_spec
95
96 View Code Duplication
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_jaccard_index(self, request, connect):
        """
        Provide the IVFLAT and FLAT configurations from gen_simple_index().

        Every other index type is skipped.
        """
        index_spec = request.param
        logging.getLogger().info(index_spec)
        if index_spec["index_type"] not in (IndexType.IVFLAT, IndexType.FLAT):
            pytest.skip("Skip index Temporary")
        return index_spec
106
107 View Code Duplication
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_hamming_index(self, request, connect):
        """
        Provide the IVFLAT and FLAT configurations from gen_simple_index().

        Every other index type is skipped.
        """
        index_spec = request.param
        logging.getLogger().info(index_spec)
        if index_spec["index_type"] not in (IndexType.IVFLAT, IndexType.FLAT):
            pytest.skip("Skip index Temporary")
        return index_spec
117
118
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_structure_index(self, request, connect):
        """
        Provide only the FLAT configuration from gen_simple_index().

        Every other index type is skipped.
        """
        index_spec = request.param
        logging.getLogger().info(index_spec)
        if index_spec["index_type"] != IndexType.FLAT:
            pytest.skip("Skip index Temporary")
        return index_spec
128
129
    """
130
    generate top-k params
131
    """
132
133
    @pytest.fixture(scope="function", params=[1, 99, 1024, 2049, 16385])
    def get_top_k(self, request):
        """
        Provide each top-k candidate in turn.

        16385 is one above the 16384 bound that test_search_top_k_flat_index
        treats as the largest accepted value.
        """
        candidate = request.param
        yield candidate
139
140
    def test_search_top_k_flat_index(self, connect, collection, get_top_k):
        """
        target: test basic search function with valid search params while varying top-k
        method: insert the shared vectors, then search with the first one for each top-k value
        expected: for top_k <= 16384 status ok, min(nb, top_k) hits and the query vector itself
                  is the best hit; for larger top_k status not ok
        """
        inserted, ids = self.init_data(connect, collection)
        status, result = connect.search(collection, get_top_k, [inserted[0]])
        if get_top_k > 16384:
            assert not status.OK()
            return
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), get_top_k)
        assert hits[0].distance <= epsilon
        assert check_result(hits, ids[0])
157
158
    def test_search_top_k_flat_index_metric_type(self, connect, collection):
        """
        target: test search with metric_type IP passed in the search params, no index built
        method: insert the shared vectors, then search with the first one using metric_type IP
        expected: status ok, top_k hits, and the query vector itself is the best hit
                  with distance >= 1 - epsilon
        """
        inserted, ids = self.init_data(connect, collection)
        ip_params = {"metric_type": MetricType.IP.value}
        status, result = connect.search(collection, top_k, [inserted[0]], params=ip_params)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert hits[0].distance >= 1 - epsilon
        assert check_result(hits, ids[0])
171
172
    @pytest.mark.level(2)
    def test_search_top_k_flat_index_metric_type_invalid(self, connect, collection):
        """
        target: test search with metric_type JACCARD passed in the search params
                after inserting float vectors
        method: insert the shared vectors, then search with the first one using metric_type JACCARD
        expected: search status not ok
        """
        inserted, _ = self.init_data(connect, collection)
        jaccard_params = {"metric_type": MetricType.JACCARD.value}
        status, _ = connect.search(collection, top_k, [inserted[0]], params=jaccard_params)
        assert not status.OK()
183
184
    def test_search_l2_index_params(self, connect, collection, get_simple_index):
        """
        target: test basic search function with valid search params, for every simple index config
        method: insert the shared vectors, build the index, search with the first two vectors
        expected: search status ok, top_k hits per query, the first query's best hit is the
                  query vector itself, and the best hit is strictly closer than the second
        """
        top_k = 10
        index_param = get_simple_index["index_param"]
        index_type = get_simple_index["index_type"]
        logging.getLogger().info(get_simple_index)
        vectors, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        query_vec = [vectors[0], vectors[1]]
        search_param = get_search_param(index_type)
        status, result = connect.search(collection, top_k, query_vec, params=search_param)
        logging.getLogger().info(result)
        # top_k is fixed at 10 here, so the search is always expected to succeed;
        # the former "top_k > 1024" failure branch could never run and was removed.
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance < result[0][1].distance
        assert result[1][0].distance < result[1][1].distance
208
209
    def test_search_l2_large_nq_index_params(self, connect, collection, get_simple_index):
        """
        target: test search with a large number of query vectors, for every simple index config
        method: insert the shared vectors, build the index, search with the first 1000 vectors
        expected: search status ok, top_k hits for the first query and its best hit is the
                  query vector itself (IVF_PQ is skipped)
        """
        log = logging.getLogger()
        top_k = 10
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")

        inserted, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        queries = inserted[:1000]
        status, result = connect.search(collection, top_k, queries, params=get_search_param(index_type))
        log.info(result)
        assert status.OK()
        first_hits = result[0]
        assert len(first_hits) == min(len(inserted), top_k)
        assert check_result(first_hits, ids[0])
        assert first_hits[0].distance <= epsilon
232
233
    def test_search_with_multi_partitions(self, connect, collection):
        """
        target: test search over several partitions: the default one plus a created tag
        method: insert 10 vectors into the tagged partition, then search with
                partition_tags ["_default", tag]
        expected: search status ok and the query vector itself is the best hit
        """
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection, nb=10, partition_tags=tag)
        flat_param = get_search_param(IndexType.FLAT)
        status, result = connect.search(
            collection, top_k, [inserted[0]],
            partition_tags=["_default", tag], params=flat_param)
        assert status.OK()
        logging.getLogger().info(result)
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance <= epsilon
250
251
    def test_search_l2_index_params_partition(self, connect, collection, get_simple_index):
        """
        target: test search on a collection that has an empty partition, for every simple index config
        method: create a partition, insert the shared vectors without a tag, build the index,
                search the whole collection and then only the tagged partition
        expected: whole-collection search is ok with top_k hits led by the query vector itself;
                  the search restricted to the empty partition is ok with an empty result
        """
        log = logging.getLogger()
        top_k = 10
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        query = [inserted[0]]
        search_param = get_search_param(index_type)

        # Search without partition restriction: data is found.
        status, result = connect.search(collection, top_k, query, params=search_param)
        log.info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance <= epsilon

        # Search restricted to the partition that received no data.
        status, result = connect.search(collection, top_k, query, partition_tags=[tag], params=search_param)
        log.info(result)
        assert status.OK()
        assert len(result) == 0
278
279
    def test_search_l2_index_params_partition_A(self, connect, collection, get_simple_index):
        """
        target: test search restricted to an empty partition, for every simple index config
        method: create a partition, insert the shared vectors without a tag, build the index,
                search with partition_tags [tag]
        expected: search status ok, and the length of the result is 0
        """
        log = logging.getLogger()
        top_k = 10
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")

        connect.create_partition(collection, tag)
        inserted, _ = self.init_data(connect, collection)
        connect.create_index(collection, index_type, index_param)
        search_param = get_search_param(index_type)
        status, result = connect.search(
            collection, top_k, [inserted[0]], partition_tags=[tag], params=search_param)
        log.info(result)
        assert status.OK()
        assert len(result) == 0
301
302
    def test_search_l2_index_params_partition_B(self, connect, collection, get_simple_index):
        """
        target: test search when all data lives in a tagged partition, for every simple index config
        method: create a partition, insert the shared vectors into it, build the index,
                search the whole collection and then only the tagged partition
        expected: both searches are ok with top_k hits led by the query vector itself
        """
        log = logging.getLogger()
        top_k = 10
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, index_type, index_param)
        query = [inserted[0]]
        search_param = get_search_param(index_type)
        # First pass searches the whole collection, second pass only the tagged partition.
        for scope in ({}, {"partition_tags": [tag]}):
            status, result = connect.search(collection, top_k, query, params=search_param, **scope)
            log.info(result)
            assert status.OK()
            assert len(result[0]) == min(len(inserted), top_k)
            assert check_result(result[0], ids[0])
            assert result[0][0].distance <= epsilon
331
332
    def test_search_l2_index_params_partition_C(self, connect, collection, get_simple_index):
        """
        target: test search with several tags where one of them was never created
        method: create a partition, insert the shared vectors into it, build the index,
                search with partition_tags [tag, "new_tag"]
        expected: search status ok with top_k hits led by the query vector itself
        """
        log = logging.getLogger()
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = self.init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, index_type, index_param)
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(
            collection, top_k, [inserted[0]],
            partition_tags=[tag, "new_tag"], params=search_param)
        log.info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance <= epsilon
356
357
    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_D(self, connect, collection, get_simple_index):
        """
        target: test search with a single tag that was never created in the collection
        method: create a partition, insert the shared vectors into it, build the index,
                search with partition_tags ["new_tag"]
        expected: search status not ok
        """
        log = logging.getLogger()
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        connect.create_partition(collection, tag)
        inserted, _ = self.init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, index_type, index_param)
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(
            collection, top_k, [inserted[0]], partition_tags=["new_tag"], params=search_param)
        log.info(result)
        assert not status.OK()
376
377
    @pytest.mark.level(2)
    def test_search_l2_index_params_partition_E(self, connect, collection, get_simple_index):
        """
        target: test search over two populated partitions, for every simple index config
        method: create two partitions, insert the shared vectors into the first and 6001 new
                vectors into the second, build the index, search with one query from each
                partition using both tags and then only the second tag
        expected: with both tags each query finds itself as the best hit; with only the
                  second tag the query taken from that partition still finds itself
        """
        log = logging.getLogger()
        top_k = 10
        second_tag = "new_tag"
        index_type, index_param = get_simple_index["index_type"], get_simple_index["index_param"]
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        log.info(get_simple_index)
        connect.create_partition(collection, tag)
        connect.create_partition(collection, second_tag)
        inserted, ids = self.init_data(connect, collection, partition_tags=tag)
        extra_inserted, extra_ids = self.init_data(connect, collection, nb=6001, partition_tags=second_tag)
        connect.create_index(collection, index_type, index_param)
        queries = [inserted[0], extra_inserted[0]]
        search_param = get_search_param(index_type)

        # Both partitions are searched: each query should find itself.
        status, result = connect.search(
            collection, top_k, queries,
            partition_tags=[tag, second_tag], params=search_param)
        log.info(result)
        assert status.OK()
        assert len(result[0]) == min(len(inserted), top_k)
        assert check_result(result[0], ids[0])
        assert check_result(result[1], extra_ids[0])
        assert result[0][0].distance <= epsilon
        assert result[1][0].distance <= epsilon

        # Only the second partition is searched: the second query should find itself.
        status, result = connect.search(
            collection, top_k, queries, partition_tags=[second_tag], params=search_param)
        log.info(result)
        assert status.OK()
        assert len(result[0]) == min(len(inserted), top_k)
        assert check_result(result[1], extra_ids[0])
        assert result[1][0].distance <= epsilon
413
414
    def test_search_l2_index_params_partition_F(self, connect, collection, get_simple_index):
        """
        target: test search with partition tags given as regular expressions
        method: create partitions "atag" and "new_tag", insert the shared vectors into the first
                and 6001 new vectors into the second, build the index, search with one query
                from each partition using "new(.*)" and then "(.*)tag"
        expected: with "new(.*)" only the query taken from "new_tag" finds itself;
                  with "(.*)tag" both queries find themselves
        """
        log = logging.getLogger()
        first_tag = "atag"
        second_tag = "new_tag"
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        if index_type == IndexType.IVF_PQ:
            pytest.skip("Skip PQ")
        connect.create_partition(collection, first_tag)
        connect.create_partition(collection, second_tag)
        inserted, ids = self.init_data(connect, collection, partition_tags=first_tag)
        extra_inserted, extra_ids = self.init_data(connect, collection, nb=6001, partition_tags=second_tag)
        connect.create_index(collection, index_type, index_param)
        queries = [inserted[0], extra_inserted[0]]
        top_k = 10
        search_param = get_search_param(index_type)

        # Pattern matching only "new_tag": the first query cannot find its own vector.
        status, result = connect.search(
            collection, top_k, queries, partition_tags=["new(.*)"], params=search_param)
        log.info(result)
        assert status.OK()
        assert result[0][0].distance > epsilon
        assert result[1][0].distance <= epsilon

        # Pattern matching both tags: both queries find themselves.
        status, result = connect.search(
            collection, top_k, queries, partition_tags=["(.*)tag"], params=search_param)
        log.info(result)
        assert status.OK()
        assert result[0][0].distance <= epsilon
        assert result[1][0].distance <= epsilon
445
446
    @pytest.mark.level(2)
    def test_search_ip_index_params(self, connect, ip_collection, get_simple_index):
        """
        target: test basic search on the ip collection, for every simple index config
        method: insert the shared vectors, build the index, search with the first vector
        expected: search status ok, top_k hits, the query vector itself is found and the
                  first distance is not smaller than the second
        """
        log = logging.getLogger()
        top_k = 10
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        inserted, ids = self.init_data(connect, ip_collection)
        connect.create_index(ip_collection, index_type, index_param)
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, [inserted[0]], params=search_param)
        log.info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance >= hits[1].distance
467
468
    def test_search_ip_large_nq_index_params(self, connect, ip_collection, get_simple_index):
        """
        target: test search on the ip collection with a large number of query vectors
        method: insert the shared vectors, build the index, search with the first 1200 vectors
        expected: search status ok, top_k hits for the first query, the query vector itself is
                  found and its distance is >= 1 - gen_inaccuracy(distance)
                  (RNSG and IVF_PQ are skipped)
        """
        log = logging.getLogger()
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(get_simple_index)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")
        inserted, ids = self.init_data(connect, ip_collection)
        connect.create_index(ip_collection, index_type, index_param)
        queries = [inserted[position] for position in range(1200)]
        top_k = 10
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, queries, params=search_param)
        log.info(result)
        assert status.OK()
        first_hits = result[0]
        assert len(first_hits) == min(len(inserted), top_k)
        assert check_result(first_hits, ids[0])
        assert first_hits[0].distance >= 1 - gen_inaccuracy(first_hits[0].distance)
492
493
    @pytest.mark.level(2)
    def test_search_ip_index_params_partition(self, connect, ip_collection, get_simple_index):
        """
        target: test search on an ip collection that has an empty partition
        method: create a partition, insert the shared vectors without a tag, build the index,
                search the whole collection and then only the tagged partition
        expected: whole-collection search is ok and finds the query vector itself;
                  the search restricted to the empty partition is ok with an empty result
                  (RNSG and IVF_PQ are skipped)
        """
        log = logging.getLogger()
        top_k = 10
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")

        connect.create_partition(ip_collection, tag)
        inserted, ids = self.init_data(connect, ip_collection)
        connect.create_index(ip_collection, index_type, index_param)
        query = [inserted[0]]
        search_param = get_search_param(index_type)

        # Search without partition restriction: data is found.
        status, result = connect.search(ip_collection, top_k, query, params=search_param)
        log.info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance >= 1 - gen_inaccuracy(hits[0].distance)

        # Search restricted to the partition that received no data.
        status, result = connect.search(ip_collection, top_k, query, partition_tags=[tag], params=search_param)
        log.info(result)
        assert status.OK()
        assert len(result) == 0
522
523
    @pytest.mark.level(2)
    def test_search_ip_index_params_partition_A(self, connect, ip_collection, get_simple_index):
        """
        target: test search on the ip collection restricted to a populated partition
        method: create a partition, insert the shared vectors into it, build the index,
                search with partition_tags [tag]
        expected: search status ok, top_k hits and the query vector itself is found
                  (RNSG and IVF_PQ are skipped)
        """
        log = logging.getLogger()
        top_k = 10
        index_param, index_type = get_simple_index["index_param"], get_simple_index["index_type"]
        log.info(index_param)
        if index_type in [IndexType.RNSG, IndexType.IVF_PQ]:
            pytest.skip("rnsg not support in ip, skip pq")

        connect.create_partition(ip_collection, tag)
        inserted, ids = self.init_data(connect, ip_collection, partition_tags=tag)
        connect.create_index(ip_collection, index_type, index_param)
        search_param = get_search_param(index_type)
        status, result = connect.search(
            ip_collection, top_k, [inserted[0]], partition_tags=[tag], params=search_param)
        log.info(result)
        assert status.OK()
        hits = result[0]
        assert len(hits) == min(len(inserted), top_k)
        assert check_result(hits, ids[0])
        assert hits[0].distance >= 1 - gen_inaccuracy(hits[0].distance)
548
549
    @pytest.mark.level(2)
    def test_search_vectors_without_connect(self, dis_connect, collection):
        """
        target: test search vectors without connection
        method: use the disconnected instance, call search with one of the shared vectors
        expected: raise exception
        """
        query_vectors = [vectors[0]]
        with pytest.raises(Exception):
            # The unpacking is kept from the original test; its targets are never read.
            _status, _ids = dis_connect.search(collection, top_k, query_vectors)
560
561
    def test_search_collection_name_not_existed(self, connect, collection):
        """
        target: search a collection that does not exist
        method: search with a freshly generated collection name
        expected: status not ok
        """
        collection_name = gen_unique_str("not_existed_collection")
        query_vecs = [vectors[0]]
        status, result = connect.search(collection_name, top_k, query_vecs)
        assert not status.OK()
572
573
    def test_search_collection_name_None(self, connect, collection):
        """
        target: search with a collection name of None
        method: call search with collection_name None
        expected: raise exception
        """
        collection_name = None
        query_vecs = [vectors[0]]
        with pytest.raises(Exception):
            # The unpacking is kept from the original test; its targets are never read.
            _status, _result = connect.search(collection_name, top_k, query_vecs)
584
585
    def test_search_top_k_query_records(self, connect, collection):
        """
        target: test search function with several query records
        method: insert the shared vectors, then search with three of them (positions 0, 55, 99)
        expected: status ok, one result list per query, each with top_k hits and a best
                  distance within epsilon
        """
        top_k = 10
        inserted, ids = self.init_data(connect, collection)
        query_vecs = [inserted[position] for position in (0, 55, 99)]
        status, result = connect.search(collection, top_k, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for position, _ in enumerate(query_vecs):
            hits = result[position]
            assert len(hits) == top_k
            assert hits[0].distance <= epsilon
600
601
    def test_search_distance_l2_flat_index(self, connect, collection):
        """
        target: search collection, and check the returned distance
        method: insert 2 vectors, search with a constant 0.5 vector and compare the square root
                of the returned distance with the smaller Euclidean norm computed with numpy
        expected: the difference is within gen_inaccuracy(returned distance)
        """
        nb = 2
        inserted, ids = self.init_data(connect, collection, nb=nb)
        query_vecs = [[0.50] * dim]
        probe = numpy.array(query_vecs[0])
        expected = min(
            numpy.linalg.norm(probe - numpy.array(inserted[0])),
            numpy.linalg.norm(probe - numpy.array(inserted[1])),
        )
        status, result = connect.search(collection, top_k, query_vecs)
        best_distance = result[0][0].distance
        assert abs(numpy.sqrt(best_distance) - expected) <= gen_inaccuracy(best_distance)
615
616
    def test_search_distance_ip_flat_index(self, connect, ip_collection):
        """
        target: search ip_collection, and check the returned distance
        method: insert 2 vectors, build a FLAT index, search with a constant 0.5 vector and
                compare the returned distance with the larger inner product computed with numpy
        expected: the difference is within gen_inaccuracy(returned distance)
        """
        nb = 2
        vectors, ids = self.init_data(connect, ip_collection, nb=nb)
        index_type = IndexType.FLAT
        index_param = {
            "nlist": 16384
        }
        connect.create_index(ip_collection, index_type, index_param)
        logging.getLogger().info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50 for i in range(dim)]]
        distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
        distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        # For inner product the best hit is the one with the largest value.
        assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
637
638 View Code Duplication
    def test_search_distance_jaccard_flat_index(self, connect, jac_collection):
        '''
        target: search jac_collection, and check the result: distance
        method: insert two binary vectors, create a FLAT index, search with a binary vector
                obtained from init_binary_data(..., insert=False) and compare the returned
                distance with the value computed by jaccard()
        expected: the returned distance equals the smaller computed value (within epsilon)
        '''
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, jac_collection, nb=2)
        connect.create_index(jac_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(jac_collection))
        log.info(connect.get_index_info(jac_collection))
        query_ints, query_vecs, _ = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
        # Reference value: the smaller of the two client-side distances.
        expected = min(
            jaccard(query_ints[0], stored)
            for stored in (stored_ints[0], stored_ints[1])
        )
        search_param = get_search_param(index_type)
        status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
        log.info(status)
        log.info(result)
        assert abs(result[0][0].distance - expected) <= epsilon
662
663
    def test_search_distance_jaccard_flat_index_metric_type(self, connect, jac_collection):
        '''
        target: search jac_collection with a metric_type passed in the search params
        method: insert two binary vectors, create a FLAT index, search with
                metric_type=HAMMING in the search params and compare the returned distance
                with the value computed by hamming()
        expected: status ok and the returned distance equals the smaller computed value
                  (within epsilon)
        '''
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, jac_collection, nb=2)
        connect.create_index(jac_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(jac_collection))
        log.info(connect.get_index_info(jac_collection))
        query_ints, query_vecs, _ = self.init_binary_data(connect, jac_collection, nb=1, insert=False)
        # Reference value: the smaller of the two client-side distances.
        expected = min(
            hamming(query_ints[0], stored)
            for stored in (stored_ints[0], stored_ints[1])
        )
        search_param = get_search_param(index_type)
        # Request HAMMING for this search through the search params.
        search_param["metric_type"] = MetricType.HAMMING.value
        status, result = connect.search(jac_collection, top_k, query_vecs, params=search_param)
        assert status.OK()
        log.info(status)
        log.info(result)
        assert abs(result[0][0].distance - expected.astype(float)) <= epsilon
689
690 View Code Duplication
    def test_search_distance_hamming_flat_index(self, connect, ham_collection):
        '''
        target: search ham_collection, and check the result: distance
        method: insert two binary vectors, create a FLAT index, search with a binary vector
                obtained from init_binary_data(..., insert=False) and compare the returned
                distance with the value computed by hamming()
        expected: the returned distance equals the smaller computed value (within epsilon)
        '''
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, ham_collection, nb=2)
        connect.create_index(ham_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(ham_collection))
        log.info(connect.get_index_info(ham_collection))
        query_ints, query_vecs, _ = self.init_binary_data(connect, ham_collection, nb=1, insert=False)
        # Reference value: the smaller of the two client-side distances.
        expected = min(
            hamming(query_ints[0], stored)
            for stored in (stored_ints[0], stored_ints[1])
        )
        search_param = get_search_param(index_type)
        status, result = connect.search(ham_collection, top_k, query_vecs, params=search_param)
        log.info(status)
        log.info(result)
        assert abs(result[0][0].distance - expected.astype(float)) <= epsilon
714
715
    def test_search_distance_substructure_flat_index(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result
        method: insert two binary vectors, create a FLAT index, then search with a binary
                vector obtained from init_binary_data(..., insert=False)
        expected: the search returns no hit for the query
        '''
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, substructure_collection, nb=2)
        connect.create_index(substructure_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(substructure_collection))
        log.info(connect.get_index_info(substructure_collection))
        query_ints, query_vecs, _ = self.init_binary_data(
            connect, substructure_collection, nb=1, insert=False)
        # The helper results are not asserted on; the calls are kept so the helper still runs.
        for stored in (stored_ints[0], stored_ints[1]):
            substructure(query_ints[0], stored)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        log.info(status)
        log.info(result)
        assert len(result[0]) == 0
740
741 View Code Duplication
    def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result: ids and distance
        method: insert two binary vectors, create a FLAT index, then search with two queries
                produced by gen_binary_sub_vectors from the inserted vectors
        expected: each query returns exactly one hit, the inserted vector at the same
                  position, with a distance not larger than epsilon
        '''
        top_k = 3
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, substructure_collection, nb=2)
        connect.create_index(substructure_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(substructure_collection))
        log.info(connect.get_index_info(substructure_collection))
        query_ints, query_vecs = gen_binary_sub_vectors(stored_ints, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        log.info(status)
        log.info(result)
        # Check the hit counts for both queries first, then the hits themselves.
        for pos in (0, 1):
            assert len(result[pos]) == 1
        for pos in (0, 1):
            assert result[pos][0].distance <= epsilon
            assert result[pos][0].id == ids[pos]
769
770
    def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result
        method: insert two binary vectors, create a FLAT index, then search with a binary
                vector obtained from init_binary_data(..., insert=False)
        expected: the search returns no hit for the query
        '''
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        connect.create_index(superstructure_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(superstructure_collection))
        log.info(connect.get_index_info(superstructure_collection))
        query_ints, query_vecs, _ = self.init_binary_data(
            connect, superstructure_collection, nb=1, insert=False)
        # The helper results are not asserted on; the calls are kept so the helper still runs.
        for stored in (stored_ints[0], stored_ints[1]):
            superstructure(query_ints[0], stored)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        log.info(status)
        log.info(result)
        assert len(result[0]) == 0
795
796 View Code Duplication
    def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result: ids and distance
        method: insert two binary vectors, create a FLAT index, then search with two queries
                produced by gen_binary_super_vectors from the inserted vectors
        expected: each query returns two hits; the first hit is an inserted vector with a
                  distance not larger than epsilon
        '''
        top_k = 3
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        connect.create_index(superstructure_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(superstructure_collection))
        log.info(connect.get_index_info(superstructure_collection))
        query_ints, query_vecs = gen_binary_super_vectors(stored_ints, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        log.info(status)
        log.info(result)
        # Check the hit counts for both queries first, then the first hit of each.
        for pos in (0, 1):
            assert len(result[pos]) == 2
        for pos in (0, 1):
            assert result[pos][0].id in ids
            assert result[pos][0].distance <= epsilon
824
825 View Code Duplication
    def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
        '''
        target: search tanimoto_collection, and check the result: distance
        method: insert two binary vectors, create a FLAT index, search with a binary vector
                obtained from init_binary_data(..., insert=False) and compare the returned
                distance with the value computed by tanimoto()
        expected: the returned distance equals the smaller computed value (within epsilon)
        '''
        log = logging.getLogger()
        index_type = IndexType.FLAT
        stored_ints, stored_bins, ids = self.init_binary_data(connect, tanimoto_collection, nb=2)
        connect.create_index(tanimoto_collection, index_type, {"nlist": 16384})
        log.info(connect.get_collection_info(tanimoto_collection))
        log.info(connect.get_index_info(tanimoto_collection))
        query_ints, query_vecs, _ = self.init_binary_data(connect, tanimoto_collection, nb=1, insert=False)
        # Reference value: the smaller of the two client-side distances.
        expected = min(
            tanimoto(query_ints[0], stored)
            for stored in (stored_ints[0], stored_ints[1])
        )
        search_param = get_search_param(index_type)
        status, result = connect.search(tanimoto_collection, top_k, query_vecs, params=search_param)
        log.info(status)
        log.info(result)
        assert abs(result[0][0].distance - expected) <= epsilon
849
850
    def test_search_distance_ip_index_params(self, connect, ip_collection, get_index):
        '''
        target: search ip_collection with each parametrized index, and check the result: distance
        method: insert two vectors, create the index supplied by get_index, search with a
                constant query vector and compare the returned distance with the inner
                product computed by numpy
        expected: the returned distance equals the larger computed inner product
                  (within gen_inaccuracy of the returned distance); RNSG is skipped
        '''
        top_k = 2
        index_type = get_index["index_type"]
        index_param = get_index["index_param"]
        if index_type == IndexType.RNSG:
            pytest.skip("rnsg not support in ip")
        log = logging.getLogger()
        inserted, ids = self.init_data(connect, ip_collection, nb=2)
        connect.create_index(ip_collection, index_type, index_param)
        log.info(connect.get_index_info(ip_collection))
        query_vecs = [[0.50] * dim]
        search_param = get_search_param(index_type)
        status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
        log.debug(status)
        log.debug(result)
        query_arr = numpy.array(query_vecs[0])
        # Reference value: the larger of the two client-side inner products.
        expected = max(
            numpy.inner(query_arr, numpy.array(stored))
            for stored in (inserted[0], inserted[1])
        )
        best_hit = result[0][0]
        assert abs(best_hit.distance - expected) <= gen_inaccuracy(best_hit.distance)
873
874
    # def test_search_concurrent(self, connect, collection):
875
    #     vectors, ids = self.init_data(connect, collection, nb=5000)
876
    #     thread_num = 50
877
    #     nq = 1
878
    #     top_k = 2
879
    #     threads = []
880
    #     query_vecs = vectors[:nq]
881
    #     def search(thread_number):
882
    #         for i in range(1000000):
883
    #             status, result = connect.search(collection, top_k, query_vecs, timeout=2)
884
    #             assert len(result) == len(query_vecs)
885
    #             assert status.OK()
886
    #             if i % 1000 == 0:
887
    #                 logging.getLogger().info("In %d, %d" % (thread_number, i))
888
    #         logging.getLogger().info("%d finished" % thread_number)
889
    #     # with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
890
    #     #     future_results = {executor.submit(
891
    #     #         search): i for i in range(1000000)}
892
    #     #     for future in concurrent.futures.as_completed(future_results):
893
    #     #         future.result()
894
    #     for i in range(thread_num):
895
    #         t = threading.Thread(target=search, args=(i, ))
896
    #         threads.append(t)
897
    #         t.start()
898
    #     for t in threads:
899
    #         t.join()
900
901 View Code Duplication
    @pytest.mark.level(2)
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, args):
        '''
        target: test concurrent search with multiple threads
        method: search from 4 threads, each thread uses its own connection
        expected: status ok, and the top hit of every query is an inserted vector
                  at distance 0.0
        '''
        nb = 100
        top_k = 10
        threads_num = 4
        threads = []
        # An exception raised inside a worker thread does not reach pytest on its own,
        # so workers record failures here and the main thread re-raises after join().
        failures = []
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        vectors, ids = self.init_data(milvus, collection, nb=nb)
        # Query with the second half of the inserted vectors.
        query_vecs = vectors[nb // 2:nb]

        def search(milvus):
            try:
                status, result = milvus.search(collection, top_k, query_vecs)
                assert status.OK()
                assert len(result) == len(query_vecs)
                for i in range(len(query_vecs)):
                    assert result[i][0].id in ids
                    assert result[i][0].distance == 0.0
            except Exception as exc:
                failures.append(exc)

        for i in range(threads_num):
            # Each worker gets a dedicated connection.
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            t = threading.Thread(target=search, args=(milvus,))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
        if failures:
            # Surface the first worker failure so the test actually fails.
            raise failures[0]
940
941
    # TODO: enable
942 View Code Duplication
    @pytest.mark.timeout(30)
    def _test_search_concurrent_multiprocessing(self, args):
        '''
        target: test concurrent search with multiple processes
        method: search from 4 processes, each process uses its own connection
        expected: every worker process exits cleanly, i.e. the top hit of every query
                  is an inserted vector at distance 0.0
        '''
        nb = 100
        top_k = 10
        process_num = 4
        processes = []
        collection = gen_unique_str("test_search_concurrent_multiprocessing")
        param = {'collection_name': collection,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(param)
        vectors, ids = self.init_data(milvus, collection, nb=nb)
        # Query with the second half of the inserted vectors.
        query_vecs = vectors[nb // 2:nb]

        def search(milvus):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert len(result) == len(query_vecs)
            for i in range(len(query_vecs)):
                assert result[i][0].id in ids
                assert result[i][0].distance == 0.0

        for i in range(process_num):
            # Each worker gets a dedicated connection.
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            p = Process(target=search, args=(milvus,))
            processes.append(p)
            p.start()
            time.sleep(0.2)
        for p in processes:
            p.join()
        # An assertion failing in a child only shows up as a non-zero exit code,
        # so check the codes explicitly once every child has been joined.
        exit_codes = [p.exitcode for p in processes]
        assert all(code == 0 for code in exit_codes), exit_codes
980
981 View Code Duplication
    def test_search_multi_collection_L2(self, args):
        '''
        target: test search multi collections of L2
        method: add vectors into 10 collections, and search
        expected: search status ok, every query returns top_k hits and check_result
                  accepts the id of the matching inserted vector
        '''
        num = 10
        top_k = 10
        # Positions (in the shared `vectors` list) of the vectors reused as queries.
        probe_positions = (0, 10, 20)
        collections = []
        expected_ids = []  # one list of ids per collection, aligned with probe_positions
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     'metric_type': MetricType.L2}
            # create collection
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            expected_ids.append([ids[pos] for pos in probe_positions])
            milvus.flush([collection])
        query_vecs = [vectors[pos] for pos in probe_positions]
        # Query every collection through the connection opened last.
        for collection, wanted in zip(collections, expected_ids):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
            for j in range(len(query_vecs)):
                assert check_result(result[j], wanted[j])
1020
1021 View Code Duplication
    def test_search_multi_collection_IP(self, args):
        '''
        target: test search multi collections of IP
        method: add vectors into 10 IP collections, and search
        expected: search status ok, every query returns top_k hits and check_result
                  accepts the id of the matching inserted vector
        '''
        num = 10
        top_k = 10
        # Positions (in the shared `vectors` list) of the vectors reused as queries.
        probe_positions = (0, 10, 20)
        collections = []
        expected_ids = []  # one list of ids per collection, aligned with probe_positions
        for i in range(num):
            collection = gen_unique_str("test_add_multicollection_%d" % i)
            param = {'collection_name': collection,
                     'dimension': dim,
                     'index_file_size': 10,
                     # This test targets inner product; it previously created L2 collections.
                     'metric_type': MetricType.IP}
            # create collection
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            milvus.create_collection(param)
            status, ids = milvus.insert(collection, vectors)
            assert status.OK()
            assert len(ids) == len(vectors)
            collections.append(collection)
            expected_ids.append([ids[pos] for pos in probe_positions])
            milvus.flush([collection])
        query_vecs = [vectors[pos] for pos in probe_positions]
        # Query every collection through the connection opened last.
        for collection, wanted in zip(collections, expected_ids):
            status, result = milvus.search(collection, top_k, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
            for j in range(len(query_vecs)):
                assert check_result(result[j], wanted[j])
1060
1061
    @pytest.fixture(params=MetricType)
    def get_binary_metric_types(self, request):
        '''
        Parametrize over MetricType, skipping INVALID as well as L2 and IP;
        every other member is returned to the test.
        '''
        metric = request.param
        if metric == MetricType.INVALID:
            pytest.skip("metric type invalid")
        if metric in (MetricType.L2, MetricType.IP):
            pytest.skip("L2 and IP not support in binary")
        return metric
1068
1069
    # 4678 and # 4683
1070
    def test_search_binary_dim_not_power_of_2(self, connect, get_binary_metric_types):
        '''
        target: search a binary collection whose dimension (200) is not a power of 2
        method: create a collection with the parametrized metric, insert 1000 binary
                vectors and search with the first inserted vector
        expected: status ok, the top hit is an inserted id at distance 0.0
        '''
        dim = 200
        top_k = 1
        collection = gen_unique_str(collection_id)
        status = connect.create_collection({
            'collection_name': collection,
            'dimension': dim,
            'index_file_size': 10,
            'metric_type': get_binary_metric_types,
        })
        assert status.OK()
        stored_ints, stored_bins, ids = self.init_binary_data(connect, collection, nb=1000, dim=dim)
        search_param = get_search_param(IndexType.FLAT)
        status, result = connect.search(collection, top_k, stored_bins[:1], params=search_param)
        assert status.OK()
        top_hit = result[0][0]
        assert top_hit.id in ids
        assert top_hit.distance == 0.0
1087
1088
    @pytest.fixture(params=MetricType)
    def get_metric_types(self, request):
        '''
        Parametrize over MetricType, skipping INVALID and every member other than
        L2 and IP; only L2 and IP are returned to the test.
        '''
        if request.param == MetricType.INVALID:
            pytest.skip("metric type invalid")
        if request.param not in [MetricType.L2, MetricType.IP]:
            # The previous message was copied from the binary fixture and described
            # the opposite condition.
            pytest.skip("only L2 and IP support in float")
        return request.param
1095
1096
    def test_search_float_dim_not_power_of_2(self, connect, get_metric_types):
        '''
        target: search a float collection whose dimension (200) is not a power of 2
        method: create a collection with the parametrized metric, insert 1000 vectors
                and search with the first inserted vector
        expected: status ok and the top hit is an inserted id
        '''
        dim = 200
        top_k = 1
        collection = gen_unique_str(collection_id)
        status = connect.create_collection({
            'collection_name': collection,
            'dimension': dim,
            'index_file_size': 10,
            'metric_type': get_metric_types,
        })
        assert status.OK()
        inserted, ids = self.init_data(connect, collection, nb=1000, dim=dim)
        search_param = get_search_param(IndexType.FLAT)
        status, result = connect.search(collection, top_k, inserted[:1], params=search_param)
        assert status.OK()
        assert result[0][0].id in ids
1112
1113
"""
1114
******************************************************************
1115
#  The following cases are used to test `search_vectors` function 
1116
#  with invalid collection_name top-k / nprobe / query_range
1117
******************************************************************
1118
"""
1119
1120
1121
class TestSearchParamsInvalid(object):
    '''
    Negative cases for `connect.search`: invalid collection names, partition
    tags, top-k values, nprobe values and search params.
    '''
    # Index settings shared by the invalid-nprobe cases.
    nlist = 16384
    index_type = IndexType.IVF_SQ8
    index_param = {"nlist": nlist}
    logging.getLogger().info(index_param)

    def init_data(self, connect, collection, nb=6000):
        '''
        Insert vectors into the collection and flush it, before searching.

        Uses the module-level `vectors` when nb == 6000, otherwise generates
        `nb` new vectors of dimension `dim`.
        Returns the inserted vectors and the ids returned by `connect.insert`.
        '''
        insert = vectors if nb == 6000 else gen_vectors(nb, dim)
        status, ids = connect.insert(collection, insert)
        # Fail here rather than in a later search if the insert was rejected
        # (same check as TestSearchBase.init_data).
        assert status.OK()
        connect.flush([collection])
        return insert, ids

    @staticmethod
    def _skip_unsupported_index(connect, index_type):
        '''
        Skip the current test when the server mode reported by
        `connect._cmd("mode")` is excluded for `index_type`.
        '''
        mode = str(connect._cmd("mode")[1])
        if mode == "CPU" and index_type == IndexType.IVF_SQ8H:
            pytest.skip("sq8h not support in CPU mode")
        if mode == "GPU" and index_type == IndexType.IVF_PQ:
            pytest.skip("ivfpq not support in GPU mode")

    def _check_invalid_top_k(self, connect, collection, top_k):
        '''
        Search `collection` with an invalid top_k: an int must produce a
        non-OK status, any other type must raise.
        '''
        logging.getLogger().info(top_k)
        query_vecs = gen_vectors(1, dim)
        if isinstance(top_k, int):
            status, result = connect.search(collection, top_k, query_vecs)
            assert not status.OK()
        else:
            with pytest.raises(Exception):
                connect.search(collection, top_k, query_vecs)

    def _check_invalid_nprobe(self, connect, collection, nprobe):
        '''
        Build the class-level IVF_SQ8 index on `collection`, search with an
        invalid nprobe and require a non-OK status.
        '''
        connect.create_index(collection, self.index_type, self.index_param)
        search_param = {"nprobe": nprobe}
        logging.getLogger().info(nprobe)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params=search_param)
        assert not status.OK()

    # ------------------------------------------------------------------
    # Test search collection with invalid collection names
    # ------------------------------------------------------------------

    @pytest.fixture(
        scope="function",
        params=gen_invalid_collection_names()
    )
    def get_collection_name(self, request):
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collectionname(self, connect, get_collection_name):
        '''
        target: test search function, with an invalid collection name
        method: search with the invalid collection name
        expected: status not ok
        '''
        collection_name = get_collection_name
        logging.getLogger().info(collection_name)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection_name, top_k, query_vecs)
        assert not status.OK()

    @pytest.mark.level(1)
    def test_search_with_invalid_tag_format(self, connect, collection):
        '''
        target: test search function, with partition_tags given as a string
        method: search with partition_tags="tag" instead of a list
        expected: raise an exception
        '''
        query_vecs = gen_vectors(1, dim)
        with pytest.raises(Exception):
            status, result = connect.search(collection, top_k, query_vecs, partition_tags="tag")
            # Only reached when search did not raise, i.e. right before the
            # test fails; the result is logged to help diagnose that failure.
            logging.getLogger().debug(result)

    @pytest.mark.level(1)
    def test_search_with_tag_not_existed(self, connect, collection):
        '''
        target: test search function, with a partition tag that was not created
        method: search with partition_tags=["tag"]
        expected: status not ok
        '''
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, partition_tags=["tag"])
        logging.getLogger().info(result)
        assert not status.OK()

    # ------------------------------------------------------------------
    # Test search collection with invalid top-k
    # ------------------------------------------------------------------

    @pytest.fixture(
        scope="function",
        params=gen_invalid_top_ks()
    )
    def get_top_k(self, request):
        return request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: status not ok for an int top_k, an exception otherwise
        '''
        self._check_invalid_top_k(connect, collection, get_top_k)

    @pytest.mark.level(2)
    def test_search_with_invalid_top_k_ip(self, connect, ip_collection, get_top_k):
        '''
        target: test search function on the ip_collection, with the wrong top_k
        method: search with top_k
        expected: status not ok for an int top_k, an exception otherwise
        '''
        self._check_invalid_top_k(connect, ip_collection, get_top_k)

    # ------------------------------------------------------------------
    # Test search collection with invalid nprobe
    # ------------------------------------------------------------------

    @pytest.fixture(
        scope="function",
        params=gen_invalid_nprobes()
    )
    def get_nprobes(self, request):
        return request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_nprobe(self, connect, collection, get_nprobes):
        '''
        target: test search function, with the wrong nprobe
        method: create an IVF_SQ8 index, then search with nprobe
        expected: status not ok
        '''
        self._check_invalid_nprobe(connect, collection, get_nprobes)

    @pytest.mark.level(2)
    def test_search_with_invalid_nprobe_ip(self, connect, ip_collection, get_nprobes):
        '''
        target: test search function on the ip_collection, with the wrong nprobe
        method: create an IVF_SQ8 index, then search with nprobe
        expected: status not ok
        '''
        self._check_invalid_nprobe(connect, ip_collection, get_nprobes)

    def test_search_with_2049_nprobe(self, connect, collection):
        '''
        target: test search function, with 2049 nprobe in GPU mode
        method: for each IVF index type, insert data, build the index and
                search with nprobe=2049
        expected: search status ok
        NOTE(review): this docstring previously said "status not ok" while the
        assertion requires OK; the assertion is kept unchanged -- confirm intent.
        '''
        if str(connect._cmd("mode")[1]) == "CPU":
            pytest.skip("Only support GPU mode")
        ivf_types = [IndexType.IVF_PQ, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
        for index in gen_simple_index():
            if index["index_type"] not in ivf_types:
                continue
            self.init_data(connect, collection)
            connect.create_index(collection, index["index_type"], index["index_param"])
            nprobe = 2049
            search_param = {"nprobe": nprobe}
            # NOTE(review): the number of query vectors equals nprobe here
            # (2049 queries) -- confirm this is intended.
            query_vecs = gen_vectors(nprobe, dim)
            status, result = connect.search(collection, top_k, query_vecs, params=search_param)
            assert status.OK()

    # ------------------------------------------------------------------
    # Test search collection with empty / invalid search params
    # ------------------------------------------------------------------

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        self._skip_unsupported_index(connect, request.param["index_type"])
        return request.param

    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: create the index, then search with params={}
        expected: status ok for FLAT, status not ok for every other index type
        '''
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        index_type = get_simple_index["index_type"]
        index_param = get_simple_index["index_param"]
        connect.create_index(collection, index_type, index_param)
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params={})

        if index_type == IndexType.FLAT:
            assert status.OK()
        else:
            assert not status.OK()

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_invalid_search_param(self, request, connect):
        self._skip_unsupported_index(connect, request.param["index_type"])
        return request.param

    def test_search_with_invalid_params(self, connect, collection, get_invalid_search_param):
        '''
        target: test search function, with invalid search params
        method: create the index matching the param's index type, then search
                with the invalid params
        expected: status not ok
        '''
        index_type = get_invalid_search_param["index_type"]
        search_param = get_invalid_search_param["search_param"]
        # Build the index using the simple-index entry of the same type.
        for index in gen_simple_index():
            if index_type == index["index_type"]:
                connect.create_index(collection, index_type, index["index_param"])
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search(collection, top_k, query_vecs, params=search_param)
        assert not status.OK()
1356
1357
1358
def check_result(result, id, limit=5):
    '''
    Return True when `id` equals the `.id` of one of the first `limit`
    entries of `result` (all entries when `result` is shorter than `limit`).

    result: indexable, sized sequence of objects exposing an `id` attribute
    id: the id to look for
    limit: how many leading entries to inspect (defaults to 5)
    '''
    inspected = min(len(result), limit)
    return any(result[pos].id == id for pos in range(inspected))
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable i does not seem to be defined.
Loading history...
1363