Passed
Push — master ( 123201...62841c )
by
unknown
01:58
created

test_search.TestSearchBase.test_search_index_partition_B()   B

Complexity

Conditions 5

Size

Total Lines 27
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 20
nop 6
dl 0
loc 27
rs 8.9332
c 0
b 0
f 0
1
import time
2
import pdb
3
import copy
4
import threading
5
import logging
6
from multiprocessing import Pool, Process
7
import pytest
8
import numpy as np
9
10
from milvus import DataType
11
from utils import *
12
13
# ---- scalar settings shared by the search tests ----
dim = 128
segment_row_count = 5000
top_k_limit = 2048          # tests expect search to raise when top_k exceeds this
collection_id = "search"    # prefix handed to gen_unique_str for collection names
tag = "1970-01-01"          # default partition tag used by the partition tests
insert_interval_time = 1.5
nb = 6000                   # size of the pre-generated entity batches below
top_k = 10
nprobe = 1
epsilon = 0.001             # tolerance used when comparing distances
field_name = default_float_vec_field_name
default_fields = gen_default_fields()
search_param = {"nprobe": nprobe}

# ---- data generated once at import time and reused by the helpers/tests ----
entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_query, default_query_vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1)
31
32
def init_data(connect, collection, nb=6000, partition_tags=None):
    '''
    Insert a batch of entities into the collection and flush it.

    The module-level `entities` batch is reused when nb == 6000; any other nb
    generates a fresh batch with gen_entities(nb, is_normal=True).
    When partition_tags is not None it is forwarded to insert() as partition_tag.

    Returns a tuple (inserted entities, ids returned by connect.insert).
    '''
    insert_entities = entities if nb == 6000 else gen_entities(nb, is_normal=True)
    # only pass partition_tag when the caller asked for a partition
    insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
    ids = connect.insert(collection, insert_entities, **insert_kwargs)
    connect.flush([collection])
    return insert_entities, ids
47
48
def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
    '''
    Prepare a batch of binary entities and optionally insert it into the collection.

    The module-level `raw_vectors` / `binary_entities` are reused when nb == 6000;
    any other nb generates a fresh batch with gen_binary_entities(nb).
    Insertion (followed by a flush) happens only when `insert is True`; otherwise
    the returned ids list is empty.

    Returns a tuple (raw vectors, entities, ids).
    '''
    if nb == 6000:
        insert_raw_vectors, insert_entities = raw_vectors, binary_entities
    else:
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)

    if insert is not True:
        return insert_raw_vectors, insert_entities, []

    insert_kwargs = {} if partition_tags is None else {"partition_tag": partition_tags}
    ids = connect.insert(collection, insert_entities, **insert_kwargs)
    connect.flush([collection])
    return insert_raw_vectors, insert_entities, ids
67
68
69
class TestSearchBase:
70
71
72
    """
73
    generate valid create_index params
74
    """
75
    @pytest.fixture(scope="function", params=gen_index())
    def get_index(self, request, connect):
        '''
        Return the current index parameter from gen_index().

        The test is skipped when the server reports "CPU" mode and the
        index_type is listed by index_cpu_not_support().
        '''
        index_param = request.param
        cpu_mode = str(connect._cmd("mode")) == "CPU"
        if cpu_mode and index_param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in CPU mode")
        return index_param
84
85
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        '''
        Return a private copy of the current index parameter from gen_simple_index().

        The test is skipped when the server reports "CPU" mode and the
        index_type is listed by index_cpu_not_support().
        '''
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        # params=... is evaluated once, so every test receives the very same dict
        # objects. Tests in this class assign get_simple_index["metric_type"] = "IP";
        # handing out a copy keeps that mutation from leaking into later tests.
        return copy.deepcopy(request.param)
94
95
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_jaccard_index(self, request, connect):
        '''
        Log the current index parameter and return it when its index_type is
        listed by binary_support(); skip the test otherwise.
        '''
        index_param = request.param
        logging.getLogger().info(index_param)
        if index_param["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return index_param
105
106
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_hamming_index(self, request, connect):
        '''
        Log the current index parameter and return it when its index_type is
        listed by binary_support(); skip the test otherwise.
        '''
        index_param = request.param
        logging.getLogger().info(index_param)
        if index_param["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return index_param
116
117
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_structure_index(self, request, connect):
        '''
        Log the current index parameter and return it only for the "FLAT"
        index_type; skip the test for every other index_type.
        '''
        index_param = request.param
        logging.getLogger().info(index_param)
        if index_param["index_type"] != "FLAT":
            pytest.skip("Skip index Temporary")
        return index_param
127
128
    """
129
    generate top-k params
130
    """
131
    @pytest.fixture(scope="function", params=[1, 10, 2049])
    def get_top_k(self, request):
        '''Provide each top-k candidate in turn: 1, 10 and 2049.'''
        top_k_value = request.param
        yield top_k_value
137
138
    @pytest.fixture(scope="function", params=[1, 10, 1100])
    def get_nq(self, request):
        '''Provide each query-count candidate in turn: 1, 10 and 1100.'''
        nq_value = request.param
        yield nq_value
144
145
    def test_search_flat(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function with valid search params while varying top-k and nq
        method: insert the default entities, then search with query vectors built from them
        expected: the first result holds top_k hits when top_k is within top_k_limit,
                  otherwise the search raises
        '''
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors_inside_entities(field_name, inserted, get_top_k, get_nq)
        if get_top_k > top_k_limit:
            # an over-limit top_k must be rejected by the search call
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res[0]) == get_top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])
163
164
    def test_search_field(self, connect, collection, get_top_k, get_nq):
        '''
        target: test search with an explicit `fields` argument while varying top-k and nq
        method: insert the default entities, search with fields=["float_vector"], then fields=["float"]
        expected: the first result holds top_k hits when top_k is within top_k_limit,
                  otherwise the search raises
        '''
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors_inside_entities(field_name, inserted, get_top_k, get_nq)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query, fields=["float_vector"])
        assert len(res[0]) == get_top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])
        # TODO
        res = connect.search(collection, query, fields=["float"])
        # TODO
185
186
    @pytest.mark.level(2)
    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function after an index has been built, for every simple index param
        method: insert the default entities, create the index, search with vectors built from the entities
        expected: nq results with at least top_k hits in the first one, or an exception when
                  top_k exceeds top_k_limit (IVF_PQ is skipped)
        '''
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        inserted, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        query, _ = gen_query_vectors_inside_entities(
            field_name, inserted, get_top_k, get_nq, search_params=index_search_param)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == get_nq
        assert len(res[0]) >= get_top_k
        assert res[0]._distances[0] < epsilon
        assert check_id_result(res[0], ids[0])
212
213
    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search on an indexed collection that also owns an (unused) partition
        method: create partition `tag`, insert the default entities without a tag, create the index,
                search the collection and then search again restricted to `tag`
        expected: nq results with at least top_k hits in the first one; the tag-restricted search
                  also returns nq results; an exception when top_k exceeds top_k_limit
        '''
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        query, _ = gen_query_vectors_inside_entities(
            field_name, inserted, get_top_k, get_nq, search_params=index_search_param)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == get_nq
        assert len(res[0]) >= get_top_k
        assert res[0]._distances[0] < epsilon
        assert check_id_result(res[0], ids[0])
        # restrict the search to the partition that received no data
        res = connect.search(collection, query, partition_tags=[tag])
        assert len(res) == get_nq
241
242
    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search restricted to partition tags on an indexed collection
        method: create partition `tag`, insert the default entities into it, create the index,
                then search with partition_tags [tag] and [tag, "new_tag"]
        expected: for both tag lists, nq results with at least top_k hits in the first one,
                  or an exception when top_k exceeds top_k_limit
        '''
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        query, _ = gen_query_vectors_inside_entities(
            field_name, inserted, get_top_k, get_nq, search_params=index_search_param)
        for tags in ([tag], [tag, "new_tag"]):
            if get_top_k > top_k_limit:
                with pytest.raises(Exception):
                    connect.search(collection, query, partition_tags=tags)
                continue
            res = connect.search(collection, query, partition_tags=tags)
            assert len(res) == get_nq
            assert len(res[0]) >= get_top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])
269
270
    @pytest.mark.level(2)
    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
        '''
        target: test search restricted to a partition tag that was never created
        method: insert the default entities, then search with partition_tags=["new_tag"]
        expected: nq results whose first entry is empty when top_k is within top_k_limit,
                  otherwise the search raises
        '''
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors_inside_entities(field_name, inserted, get_top_k, get_nq)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query, partition_tags=["new_tag"])
            return
        res = connect.search(collection, query, partition_tags=["new_tag"])
        assert len(res) == get_nq
        assert len(res[0]) == 0
288
289
    @pytest.mark.level(2)
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test search over a collection whose data lives in two partitions
        method: insert the default entities into `tag` and 6001 fresh entities into "new_tag",
                create the index, query with vectors built from the `tag` entities, first over
                the whole collection and then restricted to "new_tag"
        expected: whole-collection search yields near-zero top distances for both queries;
                  restricted to "new_tag" both top distances exceed epsilon;
                  an exception when top_k exceeds top_k_limit
        '''
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        for partition in (tag, new_tag):
            connect.create_partition(collection, partition)
        inserted, ids = init_data(connect, collection, partition_tags=tag)
        # nb=6001 forces init_data to generate a batch different from the shared one
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        query, _ = gen_query_vectors_inside_entities(
            field_name, inserted, get_top_k, nq, search_params=index_search_param)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert check_id_result(res[0], ids[0])
        assert not check_id_result(res[1], new_ids[0])
        for i in range(nq):
            assert res[i]._distances[0] < epsilon
        res = connect.search(collection, query, partition_tags=["new_tag"])
        for i in range(nq):
            assert res[i]._distances[0] > epsilon
321
322
    # TODO:
323
    @pytest.mark.level(2)
    def _test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test search with pattern-style partition tags on an indexed collection
        method: insert the default entities into "tag" and 6001 fresh entities into "new_tag",
                create the index, query with vectors built from the "new_tag" entities using
                partition_tags ["(.*)tag"] and then ["new(.*)"]
        expected: the distance/id assertions below hold, or an exception when top_k exceeds top_k_limit
        note: the leading underscore keeps this test out of pytest's default collection (see TODO)
        '''
        nq = 2
        tag = "tag"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        for partition in (tag, new_tag):
            connect.create_partition(collection, partition)
        inserted, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        query, _ = gen_query_vectors_inside_entities(
            field_name, new_entities, get_top_k, nq, search_params=index_search_param)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query, partition_tags=["(.*)tag"])
        assert not check_id_result(res[0], ids[0])
        assert check_id_result(res[1], new_ids[0])
        assert res[0]._distances[0] > epsilon
        assert res[1]._distances[0] < epsilon
        res = connect.search(collection, query, partition_tags=["new(.*)"])
        assert res[0]._distances[0] > epsilon
        assert res[1]._distances[0] < epsilon
356
357
    # 
358
    # test for ip metric
359
    # 
360
    @pytest.mark.level(2)
    def test_search_ip_flat(self, connect, ip_collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search function on ip_collection while varying top-k and nq
        method: insert the default entities, then search with query vectors built from them
        expected: the first result holds top_k hits whose top distance is at least
                  1 - gen_inaccuracy(distance); an exception when top_k exceeds top_k_limit
        '''
        inserted, ids = init_data(connect, ip_collection)
        query, _ = gen_query_vectors_inside_entities(field_name, inserted, get_top_k, get_nq)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(ip_collection, query)
            return
        res = connect.search(ip_collection, query)
        assert len(res[0]) == get_top_k
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
        assert check_id_result(res[0], ids[0])
379
380
    def test_search_ip_after_index(self, connect, ip_collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search on ip_collection after an index with metric_type "IP" has been built
        method: insert the default entities, create the index, search with vectors built from the entities
        expected: nq results with at least top_k hits in the first one and a top distance of at
                  least 1 - gen_inaccuracy(distance); an exception when top_k exceeds top_k_limit
        '''
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        inserted, ids = init_data(connect, ip_collection)
        get_simple_index["metric_type"] = "IP"
        connect.create_index(ip_collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        query, _ = gen_query_vectors_inside_entities(
            field_name, inserted, get_top_k, get_nq, search_params=index_search_param)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(ip_collection, query)
            return
        res = connect.search(ip_collection, query)
        assert len(res) == get_nq
        assert len(res[0]) >= get_top_k
        assert check_id_result(res[0], ids[0])
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
406
407
    @pytest.mark.level(2)
    def test_search_ip_index_partition(self, connect, ip_collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search on an indexed ip_collection that also owns an (unused) partition
        method: create partition `tag`, insert the default entities without a tag, create an "IP"
                index, search the collection and then search again restricted to `tag`
        expected: nq results with at least top_k hits in the first one; the tag-restricted search
                  also returns nq results; an exception when top_k exceeds top_k_limit
        '''
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(ip_collection, tag)
        inserted, ids = init_data(connect, ip_collection)
        get_simple_index["metric_type"] = "IP"
        connect.create_index(ip_collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        index_search_param["metric_type"] = "IP"
        query, _ = gen_query_vectors_inside_entities(
            field_name, inserted, get_top_k, get_nq, search_params=index_search_param)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(ip_collection, query)
            return
        res = connect.search(ip_collection, query)
        assert len(res) == get_nq
        assert len(res[0]) >= get_top_k
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
        assert check_id_result(res[0], ids[0])
        res = connect.search(ip_collection, query, partition_tags=[tag])
        assert len(res) == get_nq
438
439
    @pytest.mark.level(2)
    def test_search_ip_index_partitions(self, connect, ip_collection, get_simple_index, get_top_k):
        '''
        target: test search over an ip_collection whose data lives in two partitions
        method: insert the default entities into `tag` and 6001 fresh entities into "new_tag",
                create an "IP" index, query with vectors built from the `tag` entities, first over
                the whole collection and then restricted to "new_tag"
        expected: the distance/id assertions below hold, or an exception when top_k exceeds top_k_limit
        '''
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        for partition in (tag, new_tag):
            connect.create_partition(ip_collection, partition)
        inserted, ids = init_data(connect, ip_collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, ip_collection, nb=6001, partition_tags=new_tag)
        get_simple_index["metric_type"] = "IP"
        connect.create_index(ip_collection, field_name, get_simple_index)
        index_search_param = get_search_param(index_type)
        index_search_param["metric_type"] = "IP"
        query, _ = gen_query_vectors_inside_entities(
            field_name, inserted, get_top_k, nq, search_params=index_search_param)
        if get_top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(ip_collection, query)
            return
        res = connect.search(ip_collection, query)
        assert check_id_result(res[0], ids[0])
        assert not check_id_result(res[1], new_ids[0])
        for i in range(nq):
            assert res[i]._distances[0] >= 1 - gen_inaccuracy(res[i]._distances[0])
        res = connect.search(ip_collection, query, partition_tags=["new_tag"])
        assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
        # TODO:
        # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
474
475
    @pytest.mark.level(2)
    def test_search_without_connect(self, dis_connect, collection):
        '''
        target: test search through the dis_connect client fixture
        method: call search with the module-level default_query on dis_connect
        expected: an exception is raised
        '''
        with pytest.raises(Exception):
            dis_connect.search(collection, default_query)
484
485
    def test_search_collection_name_not_existed(self, connect):
        '''
        target: test search against a collection name that was never created
        method: build a name with gen_unique_str(collection_id) and search it with default_query
        expected: an exception is raised
        '''
        missing_collection = gen_unique_str(collection_id)
        with pytest.raises(Exception):
            connect.search(missing_collection, default_query)
494
495
    def test_search_distance_l2(self, connect, collection):
        '''
        target: search the collection and verify the returned distance
        method: insert nq entities, query with vectors from gen_query_vectors_rand_entities and
                compute l2() between the first query vector and each inserted vector
        expected: sqrt of the top returned distance matches the smaller computed distance
                  within gen_inaccuracy(returned distance)
        '''
        nq = 2
        params = {"nprobe" : 1}
        inserted, ids = init_data(connect, collection, nb=nq)
        query, vecs = gen_query_vectors_rand_entities(field_name, inserted, top_k, nq, search_params=params)
        inside_query, inside_vecs = gen_query_vectors_inside_entities(field_name, inserted, top_k, nq, search_params=params)
        expected = min(l2(vecs[0], inside_vecs[0]), l2(vecs[0], inside_vecs[1]))
        res = connect.search(collection, query)
        returned = res[0]._distances[0]
        # the assertion takes sqrt of the returned value before comparing it with l2()
        assert abs(np.sqrt(returned) - expected) <= gen_inaccuracy(returned)
510
511
    # TODO: distance problem
512
    def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
        '''
        target: search an indexed collection and verify the returned distance
        method: insert the default entities, create the index, query with vectors from
                gen_query_vectors_rand_entities and compute l2() between the first query
                vector and every inserted vector
        expected: sqrt of the top returned distance matches the smallest computed distance
                  within gen_inaccuracy(returned distance)
        note: the leading underscore keeps this test out of pytest's default collection
              (see the "distance problem" TODO)
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Take the true minimum over all inserted vectors. Seeding a running minimum
        # with 1.0 would cap the expected value whenever every distance exceeds 1.0.
        min_distance = min(l2(vecs[0], inside_vec) for inside_vec in inside_vecs)
        res = connect.search(collection, query)
        assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])
532
533
    def test_search_distance_ip(self, connect, ip_collection):
        '''
        target: search ip_collection and verify the returned distance
        method: insert nq entities, query with vectors from gen_query_vectors_rand_entities and
                compute ip() between the first query vector and each inserted vector
        expected: the top returned distance matches the larger computed value
                  within gen_inaccuracy(returned distance)
        '''
        nq = 2
        params = {"nprobe" : 1}
        inserted, ids = init_data(connect, ip_collection, nb=nq)
        query, vecs = gen_query_vectors_rand_entities(field_name, inserted, top_k, nq, search_params=params)
        inside_query, inside_vecs = gen_query_vectors_inside_entities(field_name, inserted, top_k, nq, search_params=params)
        expected = max(ip(vecs[0], inside_vecs[0]), ip(vecs[0], inside_vecs[1]))
        res = connect.search(ip_collection, query)
        returned = res[0]._distances[0]
        assert abs(returned - expected) <= gen_inaccuracy(returned)
548
549
    # TODO: distance problem
550
    def _test_search_distance_ip_after_index(self, connect, ip_collection, get_simple_index):
        '''
        target: search an indexed ip_collection and verify the returned distance
        method: insert the default entities, create an "IP" index, query with vectors from
                gen_query_vectors_rand_entities and compute ip() between the first query
                vector and every inserted vector
        expected: the top returned distance matches the largest computed value
                  within gen_inaccuracy(returned distance)
        note: the leading underscore keeps this test out of pytest's default collection
              (see the "distance problem" TODO)
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, ip_collection)
        get_simple_index["metric_type"] = "IP"
        connect.create_index(ip_collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        search_param["metric_type"] = "IP"
        query, vecs = gen_query_vectors_rand_entities(field_name, entities, top_k, nq, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Take the true maximum over all inserted vectors. Seeding a running maximum
        # with 0 would give the wrong expectation when every inner product is negative.
        max_distance = max(ip(vecs[0], inside_vec) for inside_vec in inside_vecs)
        res = connect.search(ip_collection, query)
        assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0])
572
573
    # TODO:
574
    def _test_search_distance_jaccard_flat_index(self, connect, jac_collection):
        '''
        target: search jac_collection and verify the returned distance
        method: insert 2 binary entities, generate 1 more without inserting it, and compute
                jaccard() between the generated raw vector and each inserted raw vector
        expected: the top returned distance matches the smaller computed value within epsilon
        note: the leading underscore keeps this test out of pytest's default collection (see TODO)
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, inserted, ids = init_binary_data(connect, jac_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, jac_collection, nb=1, insert=False)
        probe = query_int_vectors[0]
        distance_0 = jaccard(probe, int_vectors[0])
        distance_1 = jaccard(probe, int_vectors[1])
        res = connect.search(jac_collection, query_entities)
        expected = min(distance_0, distance_1)
        assert abs(res[0]._distances[0] - expected) <= epsilon
588
589
    def _test_search_distance_hamming_flat_index(self, connect, ham_collection):
        '''
        target: search ham_collection and verify the returned distance
        method: insert 2 binary entities, generate 1 more without inserting it, and compute
                hamming() between the generated raw vector and each inserted raw vector
        expected: the top hit's distance matches the smaller computed value (as float) within epsilon
        note: the leading underscore keeps this test out of pytest's default collection
        '''
        # from scipy.spatial import distance
        nprobe = 512
        int_vectors, inserted, ids = init_binary_data(connect, ham_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, ham_collection, nb=1, insert=False)
        probe = query_int_vectors[0]
        distance_0 = hamming(probe, int_vectors[0])
        distance_1 = hamming(probe, int_vectors[1])
        res = connect.search(ham_collection, query_entities)
        expected = min(distance_0, distance_1).astype(float)
        assert abs(res[0][0].distance - expected) <= epsilon
603
604
    def _test_search_distance_substructure_flat_index(self, connect, substructure_collection):
        '''
        target: search substructure_collection with a FLAT / SUBSTRUCTURE index
        method: insert 2 binary entities, create the index, generate 1 more entity without
                inserting it and search with it
        expected: the first result is empty
        note: the leading underscore keeps this test out of pytest's default collection
        '''
        # from scipy.spatial import distance
        nprobe = 512
        # init_binary_data is a module-level helper, not a method of this class
        int_vectors, vectors, ids = init_binary_data(connect, substructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "SUBSTRUCTURE"
        }
        connect.create_index(substructure_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, substructure_collection, nb=1, insert=False)
        distance_0 = substructure(query_int_vectors[0], int_vectors[0])
        distance_1 = substructure(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        # NOTE(review): this call passes top_k positionally and unpacks (status, result),
        # unlike the other search calls in this class -- confirm against the client API
        # before re-enabling the test.
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0
629
630
    def _test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
        '''
        target: search substructure_collection with a FLAT / SUBSTRUCTURE index
        method: insert 2 binary entities, create the index, and search with 2 query vectors
                produced by gen_binary_sub_vectors from the inserted raw vectors
        expected: each query returns exactly one hit, with distance within epsilon and the id
                  of the inserted entity at the same position
        note: the leading underscore keeps this test out of pytest's default collection
        '''
        # from scipy.spatial import distance
        top_k = 3
        nprobe = 512
        # init_binary_data is a module-level helper, not a method of this class
        int_vectors, vectors, ids = init_binary_data(connect, substructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "SUBSTRUCTURE"
        }
        connect.create_index(substructure_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(substructure_collection))
        logging.getLogger().info(connect.get_index_info(substructure_collection))
        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        # NOTE(review): this call passes top_k positionally and unpacks (status, result),
        # unlike the other search calls in this class -- confirm against the client API
        # before re-enabling the test.
        status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 1
        assert len(result[1]) == 1
        assert result[0][0].distance <= epsilon
        assert result[0][0].id == ids[0]
        assert result[1][0].distance <= epsilon
        assert result[1][0].id == ids[1]
659
660
    def _test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection through a FLAT index with a separately generated query
        method: insert 2 binary entities, build the index with the SUPERSTRUCTURE metric,
                then search with one binary vector obtained with insert=False
        expected: the query returns no hits
        NOTE(review): disabled via the leading underscore. It relies on self.init_binary_data
        and on connect.search returning (status, result) -- confirm both before re-enabling.
        '''
        # Called for its insert side effect only; the returned vectors/ids are not checked here.
        self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            # The metric must match the collection under test (was "SUBSTRUCTURE").
            "metric_type": "SUPERSTRUCTURE"
        }
        connect.create_index(superstructure_collection, binary_field_name, index_param)
        logger = logging.getLogger()
        logger.info(connect.get_collection_info(superstructure_collection))
        logger.info(connect.get_index_info(superstructure_collection))
        _, query_vecs, _ = self.init_binary_data(connect, superstructure_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logger.info(status)
        logger.info(result)
        assert len(result[0]) == 0
685
686
    def _test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
        '''
        target: search superstructure_collection through a FLAT index and check ids and distances
        method: insert 2 binary entities, then search with the 2 query vectors that
                gen_binary_super_vectors derives from the inserted vectors
        expected: every query returns 2 hits; the first hit of each query is an inserted id
                  with distance <= epsilon
        NOTE(review): disabled via the leading underscore. It relies on self.init_binary_data
        and on connect.search returning (status, result) -- confirm both before re-enabling.
        '''
        top_k = 3
        int_vectors, _, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            # The metric must match the collection under test (was "SUBSTRUCTURE").
            "metric_type": "SUPERSTRUCTURE"
        }
        connect.create_index(superstructure_collection, binary_field_name, index_param)
        logger = logging.getLogger()
        logger.info(connect.get_collection_info(superstructure_collection))
        logger.info(connect.get_index_info(superstructure_collection))
        # Only the packed query vectors are needed; the int form is discarded.
        _, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        status, result = connect.search(superstructure_collection, top_k, query_vecs, params=search_param)
        logger.info(status)
        logger.info(result)
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert result[0][0].id in ids
        assert result[0][0].distance <= epsilon
        assert result[1][0].id in ids
        assert result[1][0].distance <= epsilon
715
716
    def _test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
        '''
        target: search tanimoto_collection through a FLAT index and check the returned distance
        method: insert 2 binary entities, search with one binary vector obtained with
                insert=False, and compare the top distance with the values computed by tanimoto()
        expected: the top hit's distance equals the smaller of the two computed distances,
                  within epsilon
        NOTE(review): disabled via the leading underscore. It relies on self.init_binary_data
        and on connect.search returning (status, result) -- confirm both before re-enabling.
        '''
        int_vectors, _, _ = self.init_binary_data(connect, tanimoto_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "TANIMOTO"
        }
        connect.create_index(tanimoto_collection, binary_field_name, index_param)
        logger = logging.getLogger()
        logger.info(connect.get_collection_info(tanimoto_collection))
        logger.info(connect.get_index_info(tanimoto_collection))
        query_int_vectors, query_vecs, _ = self.init_binary_data(connect, tanimoto_collection, nb=1, insert=False)
        # Reference distances from the query to each of the two inserted vectors.
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        status, result = connect.search(tanimoto_collection, top_k, query_vecs, params=search_param)
        logger.info(status)
        logger.info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon
741
742
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, connect, args):
        '''
        target: test concurrent search from multiple threads
        method: start 4 threads, each searching through its own connection
        expected: every search returns one result whose top hit is an inserted id
                  with distance < epsilon
        '''
        threads_num = 4
        threads = []
        # Failures collected from worker threads; checked in the main thread below.
        errors = []
        collection = gen_unique_str(collection_id)
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        _, ids = init_data(milvus, collection)

        def search(milvus):
            # An exception raised inside a thread never reaches pytest, so a failed
            # assert here would otherwise leave the test green. Record it instead.
            try:
                # Search through the connection handed to this thread, not the fixture.
                res = milvus.search(collection, default_query)
                assert len(res) == 1
                assert res[0]._entities[0].id in ids
                assert res[0]._distances[0] < epsilon
            except Exception as e:
                errors.append(e)

        for _ in range(threads_num):
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
            t = threading.Thread(target=search, args=(milvus, ))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
        assert not errors, errors
772
773
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads_single_connection(self, connect, args):
        '''
        target: test concurrent search from multiple threads over one connection
        method: start 4 threads that all search through the same connection
        expected: every search returns one result whose top hit is an inserted id
                  with distance < epsilon
        '''
        threads_num = 4
        threads = []
        # Failures collected from worker threads; checked in the main thread below.
        errors = []
        collection = gen_unique_str(collection_id)
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        _, ids = init_data(milvus, collection)

        def search(milvus):
            # An exception raised inside a thread never reaches pytest, so a failed
            # assert here would otherwise leave the test green. Record it instead.
            try:
                # Use the single shared connection passed in, not the fixture.
                res = milvus.search(collection, default_query)
                assert len(res) == 1
                assert res[0]._entities[0].id in ids
                assert res[0]._distances[0] < epsilon
            except Exception as e:
                errors.append(e)

        for _ in range(threads_num):
            t = threading.Thread(target=search, args=(milvus, ))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
        assert not errors, errors
802
803
    def test_search_multi_collections(self, connect, args):
        '''
        target: test search over several freshly created collections
        method: create 10 collections, insert the default entities into each, then
                search each one with 20 query vectors taken from the inserted entities
        expected: each search returns 20 results; for query n the id ids[n] passes
                  check_id_result, the first distance is below epsilon and the second above it
        '''
        collection_count = 10
        top_k = 10
        nq = 20
        for idx in range(collection_count):
            name = gen_unique_str(collection_id + str(idx))
            connect.create_collection(name, default_fields)
            inserted, ids = init_data(connect, name)
            assert len(ids) == nb
            query, _ = gen_query_vectors_inside_entities(
                field_name, inserted, top_k, nq, search_params=search_param)
            res = connect.search(name, query)
            assert len(res) == nq
            for pos in range(nq):
                hits = res[pos]
                assert check_id_result(hits, ids[pos])
                assert hits._distances[0] < epsilon
                assert hits._distances[1] > epsilon
824
825
826
class TestSearchDSL(object):

    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    # TODO: assert exception
    def test_query_no_must(self, connect, collection):
        '''
        method: build query without must expr
        expected: error raised
        '''
        query = update_query_expr(default_query, keep_old=False)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # TODO:
    def test_query_no_vector_term_only(self, connect, collection):
        '''
        method: build query whose must expr holds only a term expr, no vector expr
        expected: error raised
        '''
        # NOTE(review): gen_default_term_expr comes from the utils star import and is
        # used here without being called -- confirm it is a value and not a factory.
        expr = {
            "must": [gen_default_term_expr]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_wrong_format(self, connect, collection):
        '''
        method: build query without must expr, with wrong expr name
        expected: error raised
        '''
        expr = {
            "must1": [gen_default_term_expr]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_empty(self, connect, collection):
        '''
        method: search with empty query
        expected: error raised
        '''
        query = {}
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_with_wrong_format_term(self, connect, collection):
        '''
        method: build query with wrong term expr
        expected: error raised
        '''
        # Copy first so the shared gen_default_term_expr object is not modified
        # for the other tests that reference it.
        expr = copy.deepcopy(gen_default_term_expr)
        expr["term"] = 1
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)
891
892
893
"""
894
******************************************************************
895
#  The following cases are used to test `search` function 
896
#  with invalid collection_name, or invalid query expr
897
******************************************************************
898
"""
899
900
class TestSearchInvalid(object):

    """
    Test search collection with invalid collection names
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # Skip index types listed by index_cpu_not_support() when the server reports CPU mode.
        if str(connect._cmd("mode")) == "CPU" and request.param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        '''
        method: search with each invalid collection name
        expected: error raised
        '''
        collection_name = get_collection_name
        with pytest.raises(Exception):
            connect.search(collection_name, default_query)

    @pytest.mark.level(1)
    def test_search_with_invalid_tag(self, connect, collection):
        '''
        method: search with a blank string passed as partition_tags
        expected: error raised
        '''
        tag = " "
        with pytest.raises(Exception):
            connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        '''
        method: search with each invalid string as the only requested field
        expected: error raised
        '''
        fields = [get_invalid_field]
        with pytest.raises(Exception):
            connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        '''
        method: search with a freshly generated unique field name as the only requested field
        expected: error raised
        '''
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception):
            connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        # Patch a private copy: writing the invalid topk into the module-level
        # default_query would leak into every test that runs afterwards.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    Test search collection with invalid search params
    """
    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    # TODO: This case can all pass, but it's too slow
    @pytest.mark.level(2)
    def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        # Decide on skipping before inserting data and building the index, so
        # mismatched (index, params) combinations cost nothing.
        if search_params["index_type"] != index_type:
            pytest.skip("Skip case")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors_inside_entities(field_name, entities, top_k, 1, search_params={})
        with pytest.raises(Exception):
            connect.search(collection, query)
1026
1027
1028
def check_id_result(result, id):
    '''
    Return True when `id` is among the ids of the first 5 entities in `result`.

    result: iterable of objects exposing an `id` attribute
    id: the id to look for
    A result with fewer than 5 entities is searched in full.
    '''
    limit_in = 5
    ids = [entity.id for entity in result]
    # Slicing already copes with short lists, so no length check is needed.
    return id in ids[:limit_in]
1035