Passed — Push to master (fd4969...54df52) by unknown, created 01:50

TestSearchInvalid._test_search_with_invalid_params()    Rating: A

Complexity
    Conditions    3

Size
    Total Lines   16
    Code Lines    11

Duplication
    Lines         16
    Ratio         100 %

Importance
    Changes       0

Metric   Value
cc       3
eloc     11
nop      5
dl       16
loc      16
rs       9.85
c        0
b        0
f        0
import time
import pdb
import copy
import threading
import logging
from multiprocessing import Pool, Process
import pytest
import numpy as np

from milvus import DataType
from utils import *

dim = 128
segment_size = 10
top_k_limit = 2048
collection_id = "search"
tag = "1970-01-01"
insert_interval_time = 1.5
nb = 6000
top_k = 10
nprobe = 1
epsilon = 0.001
field_name = "float_vector"
default_index_name = "insert_index"
default_fields = gen_default_fields()
search_param = {"nprobe": 1}
entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
query, query_vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1)
# query = {
#     "bool": {
#         "must": [
#             {"term": {"A": {"values": [1, 2, 5]}}},
#             {"range": {"B": {"ranges": {"GT": 1, "LT": 100}}}},
#             {"vector": {"Vec": {"topk": 10, "query": vec[: 1], "params": {"index_name": "IVFFLAT", "nprobe": 10}}}}
#         ],
#     },
# }
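
# Illustrative sketch: based on the DSL example above and on how
# query["bool"]["must"][0]["vector"][field_name] is indexed later in this file,
# `gen_query_vectors_inside_entities` is assumed to return a query dict of the
# shape below together with the raw query vectors. `some_vecs` is a hypothetical
# placeholder for a list of `dim`-dimensional float vectors:
#
# minimal_query = {
#     "bool": {
#         "must": [
#             {"vector": {field_name: {"topk": top_k, "query": some_vecs, "params": {"nprobe": nprobe}}}}
#         ]
#     }
# }
# res = connect.search(collection_name, minimal_query)
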
def init_data(connect, collection, nb=6000, partition_tags=None):
    '''
    Generate entities and add them to the collection
    '''
    global entities
    if nb == 6000:
        insert_entities = entities
    else:
        insert_entities = gen_entities(nb, is_normal=True)
    if partition_tags is None:
        ids = connect.insert(collection, insert_entities)
    else:
        ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
    connect.flush([collection])
    return insert_entities, ids

def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
    '''
    Generate binary entities and add them to the collection
    '''
    ids = []
    global binary_entities
    global raw_vectors
    if nb == 6000:
        insert_entities = binary_entities
        insert_raw_vectors = raw_vectors
    else:
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
    if insert is True:
        if partition_tags is None:
            ids = connect.insert(collection, insert_entities)
        else:
            ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
        connect.flush([collection])
    return insert_raw_vectors, insert_entities, ids


class TestSearchBase:


    """
    generate valid create_index params
    """
    @pytest.fixture(
        scope="function",
        params=gen_index()
    )
    def get_index(self, request, connect):
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] in binary_support():
            return request.param
        else:
            pytest.skip("Skip index Temporary")

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_hamming_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] in binary_support():
            return request.param
        else:
            pytest.skip("Skip index Temporary")

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_structure_index(self, request, connect):
        logging.getLogger().info(request.param)
        if request.param["index_type"] == "FLAT":
            return request.param
        else:
            pytest.skip("Skip index Temporary")

    """
    generate top-k params
    """
    @pytest.fixture(
        scope="function",
        params=[1, 10, 2049]
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=[1, 10, 1100]
    )
    def get_nq(self, request):
        yield request.param

    def test_search_flat(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, change top-k value
        method: search with the given vectors, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq)
        if top_k <= top_k_limit:
            res = connect.search(collection, query)
            assert len(res[0]) == top_k
            assert res[0]._distances[0] <= epsilon
            assert check_id_result(res[0], ids[0])
        else:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)

    def test_search_field(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, change top-k value
        method: search with the given vectors, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq)
        if top_k <= top_k_limit:
            res = connect.search(collection, query, fields=["vector"])
            assert len(res[0]) == top_k
            assert res[0]._distances[0] <= epsilon
            assert check_id_result(res[0], ids[0])
            # TODO
            res = connect.search(collection, query, fields=["float"])
            # TODO
        else:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)

    @pytest.mark.level(2)
    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: search with the given vectors, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])

    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: add vectors into collection, search with the given vectors, check the result
        expected: the length of the result is top_k; searching the collection with the partition tag returns empty results
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])
            res = connect.search(collection, query, partition_tags=[tag])
            assert len(res) == nq

    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: search with the given vectors, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        for tags in [[tag], [tag, "new_tag"]]:
            if top_k > top_k_limit:
                with pytest.raises(Exception) as e:
                    res = connect.search(collection, query, partition_tags=tags)
            else:
                res = connect.search(collection, query, partition_tags=tags)
                assert len(res) == nq
                assert len(res[0]) >= top_k
                assert res[0]._distances[0] < epsilon
                assert check_id_result(res[0], ids[0])

    @pytest.mark.level(2)
    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: search with the given vectors and a tag that does not exist in the collection, check the result
        expected: error raised
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query, partition_tags=["new_tag"])
        else:
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert len(res) == nq
            assert len(res[0]) == 0

    @pytest.mark.level(2)
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: search collection with the given vectors and tags, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert check_id_result(res[0], ids[0])
            assert not check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] < epsilon
            assert res[1]._distances[0] < epsilon
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert res[0]._distances[0] > epsilon
            assert res[1]._distances[0] > epsilon

    # TODO:
    @pytest.mark.level(2)
    def _test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: search collection with the given vectors and tags, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = 2
        tag = "tag"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, new_entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(collection, query)
        else:
            res = connect.search(collection, query, partition_tags=["(.*)tag"])
            assert not check_id_result(res[0], ids[0])
            assert check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] > epsilon
            assert res[1]._distances[0] < epsilon
            res = connect.search(collection, query, partition_tags=["new(.*)"])
            assert res[0]._distances[0] > epsilon
            assert res[1]._distances[0] < epsilon

    #
    # test for ip metric
    #
    @pytest.mark.level(2)
    def test_search_ip_flat(self, connect, ip_collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, change top-k value
        method: search with the given vectors, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, ip_collection)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq)
        if top_k <= top_k_limit:
            res = connect.search(ip_collection, query)
            assert len(res[0]) == top_k
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert check_id_result(res[0], ids[0])
        else:
            with pytest.raises(Exception) as e:
                res = connect.search(ip_collection, query)

    def test_search_ip_after_index(self, connect, ip_collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: search with the given vectors, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        entities, ids = init_data(connect, ip_collection)
        connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(ip_collection, query)
        else:
            res = connect.search(ip_collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert check_id_result(res[0], ids[0])
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])

    @pytest.mark.level(2)
    def test_search_ip_index_partition(self, connect, ip_collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: add vectors into collection, search with the given vectors, check the result
        expected: the length of the result is top_k; searching the collection with the partition tag returns empty results
        '''
        top_k = get_top_k
        nq = get_nq

        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(ip_collection, tag)
        entities, ids = init_data(connect, ip_collection)
        connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(ip_collection, query)
        else:
            res = connect.search(ip_collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert check_id_result(res[0], ids[0])
            res = connect.search(ip_collection, query, partition_tags=[tag])
            assert len(res) == nq

    @pytest.mark.level(2)
    def test_search_ip_index_partitions(self, connect, ip_collection, get_simple_index, get_top_k):
        '''
        target: test basic search function, all the search params are correct, test all index params, and build index
        method: search ip_collection with the given vectors and tags, check the result
        expected: the length of the result is top_k
        '''
        top_k = get_top_k
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(ip_collection, tag)
        connect.create_partition(ip_collection, new_tag)
        entities, ids = init_data(connect, ip_collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, ip_collection, nb=6001, partition_tags=new_tag)
        connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception) as e:
                res = connect.search(ip_collection, query)
        else:
            res = connect.search(ip_collection, query)
            assert check_id_result(res[0], ids[0])
            assert not check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
            res = connect.search(ip_collection, query, partition_tags=["new_tag"])
            assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
            # TODO:
            # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])

    @pytest.mark.level(2)
    def test_search_without_connect(self, dis_connect, collection):
        '''
        target: test search vectors without connection
        method: use a disconnected instance, call the search method and check whether it raises
        expected: raise exception
        '''
        with pytest.raises(Exception) as e:
            res = dis_connect.search(collection, query)

    def test_search_collection_name_not_existed(self, connect):
        '''
        target: search a collection that does not exist
        method: search with a random collection_name which is not in the db
        expected: status not ok
        '''
        collection_name = gen_unique_str(collection_id)
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, query)

    def test_search_distance_l2(self, connect, collection):
        '''
        target: search collection, and check the result: distance
        method: compare the returned distance value with the value computed with Euclidean distance
        expected: the returned distance equals the computed value
        '''
        nq = 2
        search_param = {"nprobe": 1}
        entities, ids = init_data(connect, collection, nb=nq)
        query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
        inside_query, inside_vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        distance_0 = l2(vecs[0], inside_vecs[0])
        distance_1 = l2(vecs[0], inside_vecs[1])
        res = connect.search(collection, query)
        assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])

    # TODO: distance problem
    def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
        '''
        target: search collection, and check the result: distance
        method: compare the returned distance value with the value computed with Euclidean distance
        expected: the returned distance equals the computed value
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        min_distance = 1.0
        for i in range(nb):
            tmp_dis = l2(vecs[0], inside_vecs[i])
            if min_distance > tmp_dis:
                min_distance = tmp_dis
        res = connect.search(collection, query)
        assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])

    def test_search_distance_ip(self, connect, ip_collection):
        '''
        target: search ip_collection, and check the result: distance
        method: compare the returned distance value with the value computed with inner product
        expected: the returned distance equals the computed value
        '''
        nq = 2
        search_param = {"nprobe": 1}
        entities, ids = init_data(connect, ip_collection, nb=nq)
        query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
        inside_query, inside_vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
        distance_0 = ip(vecs[0], inside_vecs[0])
        distance_1 = ip(vecs[0], inside_vecs[1])
        res = connect.search(ip_collection, query)
        assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])

    # TODO: distance problem
    def _test_search_distance_ip_after_index(self, connect, ip_collection, get_simple_index):
        '''
        target: search collection, and check the result: distance
        method: compare the returned distance value with the value computed with inner product
        expected: the returned distance equals the computed value
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, ip_collection)
        connect.create_index(ip_collection, field_name, default_index_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        max_distance = 0
        for i in range(nb):
            tmp_dis = ip(vecs[0], inside_vecs[i])
            if max_distance < tmp_dis:
                max_distance = tmp_dis
        res = connect.search(ip_collection, query)
        assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0])

    # TODO:
    def _test_search_distance_jaccard_flat_index(self, connect, jac_collection):
        '''
        target: search jac_collection, and check the result: distance
        method: compare the returned distance value with the value computed with Jaccard distance
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, entities, ids = init_binary_data(connect, jac_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, jac_collection, nb=1, insert=False)
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
        res = connect.search(jac_collection, query_entities)
        assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon

    def _test_search_distance_hamming_flat_index(self, connect, ham_collection):
        '''
        target: search ham_collection, and check the result: distance
        method: compare the returned distance value with the value computed with Hamming distance
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, entities, ids = init_binary_data(connect, ham_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, ham_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        res = connect.search(ham_collection, query_entities)
        assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon

    def _test_search_distance_substructure_flat_index(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the substructure metric
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = init_binary_data(connect, substructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, substructure_collection, nb=1, insert=False)
        distance_0 = substructure(query_int_vectors[0], int_vectors[0])
        distance_1 = substructure(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0

    def _test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
        '''
        target: search substructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the substructure (SUB) metric
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        top_k = 3
        nprobe = 512
        int_vectors, vectors, ids = init_binary_data(connect, substructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384
        }
        connect.create_index(substructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 1
        assert len(result[1]) == 1
        assert result[0][0].distance <= epsilon
        assert result[0][0].id == ids[0]
        assert result[1][0].distance <= epsilon
        assert result[1][0].id == ids[1]

    def _test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the superstructure metric
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = init_binary_data(connect, superstructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, superstructure_collection, nb=1, insert=False)
        distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
        distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0

    def _test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection, and check the result: distance
        method: compare the returned distance value with the value computed with the superstructure (SUPER) metric
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        top_k = 3
        nprobe = 512
        int_vectors, vectors, ids = init_binary_data(connect, superstructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384
        }
        connect.create_index(superstructure_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(superstructure_collection))
        logging.getLogger().info(connect.get_index_info(superstructure_collection))
        query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert result[0][0].id in ids
        assert result[0][0].distance <= epsilon
        assert result[1][0].id in ids
        assert result[1][0].distance <= epsilon

    def _test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
        '''
        target: search tanimoto_collection, and check the result: distance
        method: compare the returned distance value with the value computed with Tanimoto distance
        expected: the returned distance equals the computed value
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, vectors, ids = init_binary_data(connect, tanimoto_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384
        }
        connect.create_index(tanimoto_collection, index_type, index_param)
        logging.getLogger().info(connect.get_collection_info(tanimoto_collection))
        logging.getLogger().info(connect.get_index_info(tanimoto_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, tanimoto_collection, nb=1, insert=False)
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(tanimoto_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon

    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, connect, args):
        '''
        target: test concurrent search with multiple threads
        method: search from several threads, each thread uses its own connection
        expected: status ok and the returned vectors should be query_records
        '''
        nb = 100
        top_k = 10
        threads_num = 4
        threads = []
        collection = gen_unique_str(collection_id)
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        entities, ids = init_data(milvus, collection)
        def search(milvus):
            res = milvus.search(collection, query)
            assert len(res) == 1
            assert res[0]._entities[0].id in ids
            assert res[0]._distances[0] < epsilon
        for i in range(threads_num):
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            t = threading.Thread(target=search, args=(milvus, ))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()

    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads_single_connection(self, connect, args):
        '''
        target: test concurrent search with multiple threads sharing a single connection
        method: search from several threads, all threads share one connection
        expected: status ok and the returned vectors should be query_records
        '''
        nb = 100
        top_k = 10
        threads_num = 4
        threads = []
        collection = gen_unique_str(collection_id)
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        entities, ids = init_data(milvus, collection)
        def search(milvus):
            res = connect.search(collection, query)
            assert len(res) == 1
            assert res[0]._entities[0].id in ids
            assert res[0]._distances[0] < epsilon
        for i in range(threads_num):
            t = threading.Thread(target=search, args=(milvus, ))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()

801
        '''
802
        target: test search multi collections of L2
803
        method: add vectors into 10 collections, and search
804
        expected: search status ok, the length of result
805
        '''
806
        num = 10
807
        top_k = 10
808
        nq = 20
809
        for i in range(num):
810
            collection = gen_unique_str(collection_id+str(i))
811
            connect.create_collection(collection, default_fields)
812
            entities, ids = init_data(connect, collection)
813
            assert len(ids) == nb
814
            query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, nq, search_params=search_param)
815
            res = connect.search(collection, query)
816
            assert len(res) == nq
817
            for i in range(nq):
818
                assert check_id_result(res[i], ids[i])
819
                assert res[i]._distances[0] < epsilon
820
                assert res[i]._distances[1] > epsilon
821
822
"""
823
******************************************************************
824
#  The following cases are used to test `search_vectors` function 
825
#  with invalid collection_name top-k / nprobe / query_range
826
******************************************************************
827
"""
828
829
class TestSearchInvalid(object):

    """
    Test search collection with invalid collection names
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        collection_name = get_collection_name
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, query)

    @pytest.mark.level(1)
    def test_search_with_invalid_tag(self, connect, collection):
        tag = " "
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        fields = [get_invalid_field]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query, fields=fields)

    """
890
    Test search collection with invalid query
891
    """
892
    @pytest.fixture(
893
        scope="function",
894
        params=gen_invalid_ints()
895
    )
896
    def get_top_k(self, request):
897
        yield request.param
898
899
    @pytest.mark.level(1)
900
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
901
        '''
902
        target: test search fuction, with the wrong top_k
903
        method: search with top_k
904
        expected: raise an error, and the connection is normal
905
        '''
906
        top_k = get_top_k
907
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
908
        with pytest.raises(Exception) as e:
909
            res = connect.search(collection, query)
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable query does not seem to be defined.
Loading history...
910
911
    """
912
    Test search collection with invalid search params
913
    """
914
    @pytest.fixture(
915
        scope="function",
916
        params=gen_invaild_search_params()
917
    )
918
    def get_search_params(self, request):
919
        yield request.param
920
921
    # TODO: This case can all pass, but it's too slow
922 View Code Duplication
    @pytest.mark.level(2)
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
923
    def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
924
        '''
925
        target: test search fuction, with the wrong nprobe
926
        method: search with nprobe
927
        expected: raise an error, and the connection is normal
928
        '''
929
        search_params = get_search_params
930
        index_type = get_simple_index["index_type"]
931
        entities, ids = init_data(connect, collection)
932
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
933
        if search_params["index_type"] != index_type:
934
            pytest.skip("Skip case")
935
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1, search_params=search_params["search_params"])
936
        with pytest.raises(Exception) as e:
937
            res = connect.search(collection, query)
938
    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function with empty search params
        method: search with empty params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1, search_params={})
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)


def check_id_result(result, id):
    limit_in = 5
    ids = [entity.id for entity in result]
    if len(result) >= limit_in:
        return id in ids[:limit_in]
    else:
        return id in ids