Passed
Push — master ( 237e90...f9e8c1 )
by
unknown
08:38 queued 01:26
created

TestSearchDSL._test_query_range_valid_ranges()   A

Complexity

Conditions 1

Size

Total Lines 13
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 9
nop 4
dl 0
loc 13
rs 9.95
c 0
b 0
f 0
1
import time
2
import pdb
3
import copy
4
import threading
5
import logging
6
from multiprocessing import Pool, Process
7
import pytest
8
import numpy as np
9
10
from milvus import DataType
11
from utils import *
12
13
# ---------------------------------------------------------------------------
# Shared test parameters
# ---------------------------------------------------------------------------
dim = 128
segment_row_count = 5000
# Largest top_k accepted by search; the tests expect anything above it to raise.
top_k_limit = 2048
collection_id = "search"
# Partition tag used by the partition tests.
tag = "1970-01-01"
insert_interval_time = 1.5
# Row count that init_data / init_binary_data serve from the pre-generated data below.
nb = 6000
top_k = 10
nq = 1
nprobe = 1
# Tolerance used when comparing returned distances.
epsilon = 0.001
field_name = default_float_vec_field_name
default_fields = gen_default_fields()
search_param = {"nprobe": 1}

# ---------------------------------------------------------------------------
# Data generated once at import time and reused by the helpers below
# ---------------------------------------------------------------------------
entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, nq)
def init_data(connect, collection, nb=6000, partition_tags=None):
    '''
    Insert float-vector entities into ``collection`` and flush it.

    With the default ``nb`` of 6000 the module-level ``entities`` are reused;
    any other ``nb`` generates fresh entities via ``gen_entities``.
    ``partition_tags`` (when given) is forwarded as ``partition_tag`` to insert.

    Returns ``(inserted_entities, ids)`` where ``ids`` is what ``connect.insert`` returned.
    '''
    to_insert = entities if nb == 6000 else gen_entities(nb, is_normal=True)
    if partition_tags is not None:
        inserted_ids = connect.insert(collection, to_insert, partition_tag=partition_tags)
    else:
        inserted_ids = connect.insert(collection, to_insert)
    connect.flush([collection])
    return to_insert, inserted_ids
def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
    '''
    Prepare binary entities and, when ``insert is True``, insert them into ``collection`` and flush.

    With the default ``nb`` of 6000 the module-level ``raw_vectors`` /
    ``binary_entities`` are reused; otherwise ``gen_binary_entities(nb)`` is called.

    Returns ``(raw_vectors, entities, ids)``; ``ids`` is ``[]`` when nothing was inserted.
    '''
    if nb == 6000:
        vectors, payload = raw_vectors, binary_entities
    else:
        vectors, payload = gen_binary_entities(nb)
    # Only the literal True triggers an insert (identity check kept on purpose).
    if insert is not True:
        return vectors, payload, []
    if partition_tags is None:
        inserted_ids = connect.insert(collection, payload)
    else:
        inserted_ids = connect.insert(collection, payload, partition_tag=partition_tags)
    connect.flush([collection])
    return vectors, payload, inserted_ids
class TestSearchBase:
73
    """
74
    generate valid create_index params
75
    """
76
77
    @pytest.fixture(
78
        scope="function",
79
        params=gen_index()
80
    )
81
    def get_index(self, request, connect):
82
        if str(connect._cmd("mode")) == "CPU":
83
            if request.param["index_type"] in index_cpu_not_support():
84
                pytest.skip("sq8h not support in CPU mode")
85
        return request.param
86
87
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        '''
        Return each index configuration produced by gen_simple_index().

        Configurations listed in index_cpu_not_support() are skipped when the
        server reports "CPU" mode.
        '''
        index_param = request.param
        running_on_cpu = str(connect._cmd("mode")) == "CPU"
        if running_on_cpu and index_param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in CPU mode")
        return index_param
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_jaccard_index(self, request, connect):
        '''
        Return the index configuration when its index_type is in binary_support();
        skip the test otherwise.
        '''
        candidate = request.param
        logging.getLogger().info(candidate)
        if candidate["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return candidate
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_hamming_index(self, request, connect):
        '''
        Return the index configuration when its index_type is in binary_support();
        skip the test otherwise.
        '''
        candidate = request.param
        logging.getLogger().info(candidate)
        if candidate["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return candidate
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_structure_index(self, request, connect):
        '''
        Return the index configuration only for the "FLAT" index_type;
        skip the test for every other type.
        '''
        candidate = request.param
        logging.getLogger().info(candidate)
        if candidate["index_type"] != "FLAT":
            pytest.skip("Skip index Temporary")
        return candidate
    """
131
    generate top-k params
132
    """
133
134
    @pytest.fixture(
135
        scope="function",
136
        params=[1, 10, 2049]
137
    )
138
    def get_top_k(self, request):
139
        yield request.param
140
141
    @pytest.fixture(
        scope="function",
        params=[1, 10, 1100]
    )
    def get_nq(self, request):
        '''
        Parametrize the number of query vectors (nq) per search.
        '''
        chosen_nq = request.param
        yield chosen_nq
    def test_search_flat(self, connect, collection, get_top_k, get_nq):
        '''
        target: basic search on a collection without an index, varying top_k and nq
        method: insert the default entities, search with query vectors generated from them
        expected: top_k hits, the best distance is within epsilon and the first
                  inserted id passes check_id_result; top_k above top_k_limit raises
        '''
        k, query_count = get_top_k, get_nq
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, inserted, k, query_count)
        if k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        hits = connect.search(collection, query)
        assert len(hits[0]) == k
        assert hits[0]._distances[0] <= epsilon
        assert check_id_result(hits[0], ids[0])
    def test_search_field(self, connect, collection, get_top_k, get_nq):
        '''
        target: search while restricting the returned fields, varying top_k and nq
        method: insert the default entities, search with fields=["float_vector"],
                then search again with fields=["float"]
        expected: top_k hits for the first search, best distance within epsilon and
                  the first inserted id found; top_k above top_k_limit raises
        '''
        k, query_count = get_top_k, get_nq
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, inserted, k, query_count)
        if k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        hits = connect.search(collection, query, fields=["float_vector"])
        assert len(hits[0]) == k
        assert hits[0]._distances[0] <= epsilon
        assert check_id_result(hits[0], ids[0])
        # TODO: the result of the fields=["float"] search is not verified yet.
        connect.search(collection, query, fields=["float"])
    @pytest.mark.level(2)
    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: search after building each simple index type (IVF_PQ skipped)
        method: insert the default entities, build the index, search with query
                vectors generated from the inserted entities
        expected: nq result sets of at least top_k hits, best distance below epsilon
                  and the first inserted id found; top_k above top_k_limit raises
        '''
        k, query_count = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        inserted, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, _ = gen_query_vectors(field_name, inserted, k, query_count,
                                     search_params=get_search_param(index_type))
        if k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        hits = connect.search(collection, query)
        assert len(hits) == query_count
        assert len(hits[0]) >= k
        assert hits[0]._distances[0] < epsilon
        assert check_id_result(hits[0], ids[0])
    @pytest.mark.level(2)
    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: search an indexed collection that also owns a partition receiving no data
        method: create partition `tag`, insert the default entities without a partition
                tag, build the index, then search the whole collection and only `tag`
        expected: the whole-collection search returns nq result sets of at least top_k
                  hits, best distance below epsilon and the first inserted id found;
                  the `tag` search returns nq result sets; top_k above top_k_limit raises
        '''
        k, query_count = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, _ = gen_query_vectors(field_name, inserted, k, query_count,
                                     search_params=get_search_param(index_type))
        if k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        whole = connect.search(collection, query)
        assert len(whole) == query_count
        assert len(whole[0]) >= k
        assert whole[0]._distances[0] < epsilon
        assert check_id_result(whole[0], ids[0])
        tag_only = connect.search(collection, query, partition_tags=[tag])
        assert len(tag_only) == query_count
    @pytest.mark.level(2)
    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: search an indexed partition by tag, alone and together with a missing tag
        method: insert the default entities into partition `tag`, build the index,
                search with partition_tags=[tag] and [tag, "new_tag"]
        expected: each search returns nq result sets of at least top_k hits, best
                  distance below epsilon and the first inserted id found;
                  top_k above top_k_limit raises
        '''
        k, query_count = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        inserted, ids = init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, field_name, get_simple_index)
        query, _ = gen_query_vectors(field_name, inserted, k, query_count,
                                     search_params=get_search_param(index_type))
        for partition_tags in ([tag], [tag, "new_tag"]):
            if k > top_k_limit:
                with pytest.raises(Exception):
                    connect.search(collection, query, partition_tags=partition_tags)
                continue
            hits = connect.search(collection, query, partition_tags=partition_tags)
            assert len(hits) == query_count
            assert len(hits[0]) >= k
            assert hits[0]._distances[0] < epsilon
            assert check_id_result(hits[0], ids[0])
    @pytest.mark.level(2)
    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
        '''
        target: search restricted to a partition tag that does not exist in the collection
        method: insert the default entities without a partition tag, then search with
                partition_tags=["new_tag"]
        expected: no error; the search returns nq result sets and the first one is empty.
                  A top_k above top_k_limit raises.
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query, partition_tags=["new_tag"])
        else:
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert len(res) == nq
            assert len(res[0]) == 0
    @pytest.mark.level(2)
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: search an indexed collection whose data is split across two partitions
        method: insert the default entities into `tag` and 6001 generated entities into
                "new_tag", build the index, query with 2 vectors generated from the
                `tag` entities, first over the whole collection, then only "new_tag"
        expected: whole collection: `tag` id found, best distances below epsilon;
                  "new_tag" only: best distances above epsilon;
                  top_k above top_k_limit raises
        '''
        k = get_top_k
        query_count = 2
        other_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, other_tag)
        inserted, ids = init_data(connect, collection, partition_tags=tag)
        _, other_ids = init_data(connect, collection, nb=6001, partition_tags=other_tag)
        connect.create_index(collection, field_name, get_simple_index)
        query, _ = gen_query_vectors(field_name, inserted, k, query_count,
                                     search_params=get_search_param(index_type))
        if k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        whole = connect.search(collection, query)
        assert check_id_result(whole[0], ids[0])
        assert not check_id_result(whole[1], other_ids[0])
        assert whole[0]._distances[0] < epsilon
        assert whole[1]._distances[0] < epsilon
        other_only = connect.search(collection, query, partition_tags=[other_tag])
        assert other_only[0]._distances[0] > epsilon
        assert other_only[1]._distances[0] > epsilon
    # TODO: not collected under pytest's default test_* naming (leading underscore).
    @pytest.mark.level(2)
    def _test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: search partitions selected by regex-like tag patterns
        method: insert the default entities into "tag" and 6001 generated entities into
                "new_tag", build the index, query with 2 vectors generated from the
                "new_tag" entities using partition_tags=["(.*)tag"] and ["new(.*)"]
        expected: only the "new_tag" data matches closely (second query below epsilon);
                  top_k above top_k_limit raises
        '''
        k = get_top_k
        query_count = 2
        first_tag, second_tag = "tag", "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, first_tag)
        connect.create_partition(collection, second_tag)
        _, ids = init_data(connect, collection, partition_tags=first_tag)
        second_entities, second_ids = init_data(connect, collection, nb=6001, partition_tags=second_tag)
        connect.create_index(collection, field_name, get_simple_index)
        query, _ = gen_query_vectors(field_name, second_entities, k, query_count,
                                     search_params=get_search_param(index_type))
        if k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        any_tag = connect.search(collection, query, partition_tags=["(.*)tag"])
        assert not check_id_result(any_tag[0], ids[0])
        assert check_id_result(any_tag[1], second_ids[0])
        assert any_tag[0]._distances[0] > epsilon
        assert any_tag[1]._distances[0] < epsilon
        new_prefixed = connect.search(collection, query, partition_tags=["new(.*)"])
        assert new_prefixed[0]._distances[0] > epsilon
        assert new_prefixed[1]._distances[0] < epsilon
    # ---------------------------------------------------------------------
    # test for ip metric
    # ---------------------------------------------------------------------
    @pytest.mark.level(2)
    def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: IP-metric search on a collection without an index, varying top_k and nq
        method: insert the default entities, search with IP query vectors generated from them
        expected: top_k hits, best score at least 1 - gen_inaccuracy(score) and the
                  first inserted id found; top_k above top_k_limit raises
        '''
        # NOTE(review): get_simple_index is not used here; it only repeats the same
        # search once per simple-index parametrization.
        k, query_count = get_top_k, get_nq
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, inserted, k, query_count, metric_type="IP")
        if k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        hits = connect.search(collection, query)
        assert len(hits[0]) == k
        best_score = hits[0]._distances[0]
        assert best_score >= 1 - gen_inaccuracy(best_score)
        assert check_id_result(hits[0], ids[0])
    @pytest.mark.level(2)
    def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: IP-metric search after building each simple index type (IVF_PQ skipped)
        method: insert the default entities, build the index with metric_type "IP",
                search with IP query vectors generated from the inserted entities
        expected: nq result sets of at least top_k hits, the first inserted id found
                  and a best score of at least 1 - gen_inaccuracy(score);
                  top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = get_nq
        metric_type = "IP"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        entities, ids = init_data(connect, collection)
        # Override the metric on a copy: pytest hands the same request.param object
        # to every test sharing this parametrization, so writing metric_type into it
        # would leak "IP" into later tests that expect the default metric.
        index_params = dict(get_simple_index, metric_type=metric_type)
        connect.create_index(collection, field_name, index_params)
        search_param = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
                                     search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert check_id_result(res[0], ids[0])
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
    @pytest.mark.level(2)
    def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: IP-metric search on an indexed collection that also owns a partition
                receiving no data
        method: create partition `tag`, insert the default entities without a partition
                tag, build the index with metric_type "IP", then search the whole
                collection and only `tag`
        expected: the whole-collection search returns nq result sets of at least top_k
                  hits, a best score of at least 1 - gen_inaccuracy(score) and the first
                  inserted id found; the `tag` search returns nq result sets;
                  top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = get_nq
        metric_type = "IP"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection)
        # Override the metric on a copy: request.param is shared by every test using
        # this parametrization, so mutating it in place would leak "IP" into them.
        index_params = dict(get_simple_index, metric_type=metric_type)
        connect.create_index(collection, field_name, index_params)
        search_param = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
                                     search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
        assert check_id_result(res[0], ids[0])
        res = connect.search(collection, query, partition_tags=[tag])
        assert len(res) == nq
    @pytest.mark.level(2)
    def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: IP-metric search on an indexed collection split across two partitions
        method: insert the default entities into `tag` and 6001 generated entities into
                "new_tag", build the index with metric_type "IP", query with 2 vectors
                generated from the `tag` entities over the whole collection, then only
                "new_tag"
        expected: whole collection: `tag` id found, both best scores at least
                  1 - gen_inaccuracy(score); "new_tag" only: first best score below that
                  bound; top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = 2
        metric_type = "IP"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        _, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        # Override the metric on a copy: request.param is shared by every test using
        # this parametrization, so mutating it in place would leak "IP" into them.
        index_params = dict(get_simple_index, metric_type=metric_type)
        connect.create_index(collection, field_name, index_params)
        search_param = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
                                     search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert check_id_result(res[0], ids[0])
        assert not check_id_result(res[1], new_ids[0])
        assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
        assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
        res = connect.search(collection, query, partition_tags=[new_tag])
        assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
        # TODO:
        # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
    @pytest.mark.level(2)
    def test_search_without_connect(self, dis_connect, collection):
        '''
        target: search through a disconnected client
        method: call search on the dis_connect instance with the default query
        expected: an exception is raised
        '''
        with pytest.raises(Exception):
            dis_connect.search(collection, default_query)
    def test_search_collection_name_not_existed(self, connect):
        '''
        target: search a collection that does not exist
        method: search a name produced by gen_unique_str(collection_id) (assumed not to
                exist in the database) with the default query
        expected: an exception is raised
        '''
        missing_collection = gen_unique_str(collection_id)
        with pytest.raises(Exception):
            connect.search(missing_collection, default_query)
    def test_search_distance_l2(self, connect, collection):
        '''
        target: check the distance returned by an L2 search
        method: insert 2 entities, search with a random vector, and compare the square
                root of the returned distance with the smaller of the two l2() values
                computed locally against the inserted vectors
        expected: the values agree within gen_inaccuracy(returned distance)
        '''
        query_count = 2
        params = {"nprobe": 1}
        inserted, ids = init_data(connect, collection, nb=query_count)
        query, query_vecs = gen_query_vectors(field_name, inserted, top_k, query_count,
                                              rand_vector=True, search_params=params)
        _, stored_vecs = gen_query_vectors(field_name, inserted, top_k, query_count,
                                           search_params=params)
        expected = min(l2(query_vecs[0], stored_vecs[0]), l2(query_vecs[0], stored_vecs[1]))
        hits = connect.search(collection, query)
        returned = hits[0]._distances[0]
        assert abs(np.sqrt(returned) - expected) <= gen_inaccuracy(returned)
    # TODO: distance problem
    def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
        '''
        target: check the distance returned by an L2 search on an indexed collection
        method: insert the default entities, build the index, search with a random
                vector and compare sqrt(returned distance) with the minimum l2() value
                computed locally against every inserted vector
        expected: the values agree within gen_inaccuracy(returned distance)
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Take the true minimum; the former 1.0 starting value silently capped the
        # expected distance whenever every vector was farther away than 1.0.
        min_distance = min(l2(vecs[0], inside_vecs[i]) for i in range(nb))
        res = connect.search(collection, query)
        assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])
    # TODO
    @pytest.mark.level(2)
    def test_search_distance_ip(self, connect, collection):
        '''
        target: check the distance returned by an IP search
        method: insert 2 entities, search with a random IP query vector, and compare the
                returned distance with the larger of the two ip() values computed
                locally against the inserted vectors
        expected: the values agree within gen_inaccuracy(returned distance)
        '''
        query_count = 2
        metric = "IP"
        params = {"nprobe": 1}
        inserted, ids = init_data(connect, collection, nb=query_count)
        query, query_vecs = gen_query_vectors(field_name, inserted, top_k, query_count, rand_vector=True,
                                              metric_type=metric, search_params=params)
        _, stored_vecs = gen_query_vectors(field_name, inserted, top_k, query_count, search_params=params)
        expected = max(ip(query_vecs[0], stored_vecs[0]), ip(query_vecs[0], stored_vecs[1]))
        hits = connect.search(collection, query)
        returned = hits[0]._distances[0]
        assert abs(returned - expected) <= gen_inaccuracy(returned)
    # TODO: distance problem
    def _test_search_distance_ip_after_index(self, connect, collection, get_simple_index):
        '''
        target: check the distance returned by an IP search on an indexed collection
        method: insert the default entities, build the index with metric_type "IP",
                search with a random vector and compare the returned distance with the
                maximum ip() value computed locally against every inserted vector
        expected: the values agree within gen_inaccuracy(returned distance)
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        metric_type = "IP"
        entities, ids = init_data(connect, collection)
        # Override the metric on a copy: request.param is shared by every test using
        # this parametrization, so mutating it in place would leak "IP" into them.
        index_params = dict(get_simple_index, metric_type=metric_type)
        connect.create_index(collection, field_name, index_params)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metric_type,
                                        search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Take the true maximum; a 0 starting value would hide it whenever every
        # inner product is negative.
        max_distance = max(ip(vecs[0], inside_vecs[i]) for i in range(nb))
        res = connect.search(collection, query)
        assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0])
    # TODO:
    def _test_search_distance_jaccard_flat_index(self, connect, binary_collection):
        '''
        target: check the distance returned by a search on binary data against jaccard()
        method: insert 2 binary entities, generate 1 more without inserting it, search
                with it and compare with the smaller of the two jaccard() values
                computed locally
        expected: the values agree within epsilon
        '''
        stored_raw, _, _ = init_binary_data(connect, binary_collection, nb=2)
        query_raw, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        expected = min(jaccard(query_raw[0], stored_raw[0]), jaccard(query_raw[0], stored_raw[1]))
        hits = connect.search(binary_collection, query_entities)
        assert abs(hits[0]._distances[0] - expected) <= epsilon
    def _test_search_distance_hamming_flat_index(self, connect, binary_collection):
        '''
        target: check the distance returned by a search on binary data against hamming()
        method: insert 2 binary entities, generate 1 more without inserting it, search
                with it and compare with the smaller of the two hamming() values
                computed locally
        expected: the values agree within epsilon
        '''
        stored_raw, _, _ = init_binary_data(connect, binary_collection, nb=2)
        query_raw, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        closest = min(hamming(query_raw[0], stored_raw[0]), hamming(query_raw[0], stored_raw[1]))
        hits = connect.search(binary_collection, query_entities)
        assert abs(hits[0][0].distance - closest.astype(float)) <= epsilon
    def _test_search_distance_substructure_flat_index(self, connect, binary_collection):
        '''
        target: search binary data with the SUBSTRUCTURE metric on a FLAT index
        method: insert 2 binary entities, create a FLAT index with metric_type
                "SUBSTRUCTURE", then search with one generated (not inserted) binary vector
        expected: the first query returns no hits
        '''
        # init_binary_data is the module-level helper; this class defines no method
        # of that name in the visible code, so ``self.init_binary_data`` raised AttributeError.
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "SUBSTRUCTURE"
        }
        # NOTE(review): binary_field_name is not defined in the visible part of this
        # file; presumably it should come from `utils` — confirm before re-enabling.
        connect.create_index(binary_collection, binary_field_name, index_param)
        logger = logging.getLogger()
        logger.info(connect.get_collection_info(binary_collection))
        logger.info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        # Computed for reference; the assertion below only checks for an empty result.
        distance_0 = substructure(query_int_vectors[0], int_vectors[0])
        distance_1 = substructure(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        # NOTE(review): this call uses a (collection, top_k, vectors, params=...) -> (status, result)
        # convention, unlike connect.search(collection, query) used by the other tests here.
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logger.info(status)
        logger.info(result)
        assert len(result[0]) == 0
    def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
        '''
        target: search binary_collection with a FLAT index using the SUBSTRUCTURE metric
        method: insert 2 binary vectors and search with 2 query vectors built from them by
                gen_binary_sub_vectors (presumably substructures of each stored vector)
        expected: query i returns exactly one hit, the i-th inserted id, with distance <= epsilon
        '''
        top_k = 3
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "SUBSTRUCTURE"
        }
        # NOTE(review): `binary_field_name` is not defined in the visible part of this module;
        # confirm it is provided (e.g. by utils) before re-enabling this test.
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        # NOTE(review): legacy search signature (top_k, vectors, params) returning (status, result).
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 1
        assert len(result[1]) == 1
        assert result[0][0].distance <= epsilon
        assert result[0][0].id == ids[0]
        assert result[1][0].distance <= epsilon
        assert result[1][0].id == ids[1]
670
671 View Code Duplication
    def _test_search_distance_superstructure_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection with a FLAT index using the SUPERSTRUCTURE metric
        method: insert 2 binary vectors, then search with a separately generated binary vector
                that is not inserted
        expected: the result for the query is empty
        '''
        # insert two binary vectors into the collection
        init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            # this is the superstructure test, so the index must use SUPERSTRUCTURE
            # (it previously used SUBSTRUCTURE, duplicating the substructure test)
            "metric_type": "SUPERSTRUCTURE"
        }
        # NOTE(review): `binary_field_name` is not defined in the visible part of this module;
        # confirm it is provided (e.g. by utils) before re-enabling this test.
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        # generate (but do not insert) one query vector
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        # NOTE(review): legacy search signature (top_k, vectors, params) returning (status, result).
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0
696
697 View Code Duplication
    def _test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
        '''
        target: search binary_collection with a FLAT index using the SUPERSTRUCTURE metric
        method: insert 2 binary vectors and search with 2 query vectors built from them by
                gen_binary_super_vectors (presumably superstructures of the stored vectors)
        expected: each query returns 2 hits, and its top hit is an inserted id with
                  distance <= epsilon
        '''
        top_k = 3
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            # this is the superstructure test, so the index must use SUPERSTRUCTURE
            # (it previously used SUBSTRUCTURE)
            "metric_type": "SUPERSTRUCTURE"
        }
        # NOTE(review): `binary_field_name` is not defined in the visible part of this module;
        # confirm it is provided (e.g. by utils) before re-enabling this test.
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        # NOTE(review): legacy search signature (top_k, vectors, params) returning (status, result).
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert result[0][0].id in ids
        assert result[0][0].distance <= epsilon
        assert result[1][0].id in ids
        assert result[1][0].distance <= epsilon
726
727 View Code Duplication
    def _test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection with a FLAT index using the TANIMOTO metric
        method: compare the top returned distance with the Tanimoto distance computed locally
                between the query and each of the 2 inserted vectors
        expected: the top distance equals the smaller computed distance (within epsilon)
        '''
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "TANIMOTO"
        }
        # NOTE(review): `binary_field_name` is not defined in the visible part of this module;
        # confirm it is provided (e.g. by utils) before re-enabling this test.
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        # generate (but do not insert) one query vector
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        # NOTE(review): legacy search signature (top_k, vectors, params) returning (status, result).
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon
752
753
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, connect, args):
        '''
        target: test concurrent search with multiple threads
        method: search from 4 threads, each thread using its own connection
        expected: every search succeeds; the top hit is an inserted entity with distance < epsilon
        '''
        threads_num = 4
        threads = []
        errors = []
        collection = gen_unique_str(collection_id)
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        _, ids = init_data(milvus, collection)

        def search(client):
            # Exceptions raised in a worker thread never reach pytest, so record them
            # and re-raise from the main thread once all workers have finished.
            try:
                res = client.search(collection, default_query)
                assert len(res) == 1
                assert res[0]._entities[0].id in ids
                assert res[0]._distances[0] < epsilon
            except Exception as e:
                errors.append(e)

        for _ in range(threads_num):
            # each thread gets its own, independent connection
            client = get_milvus(args["ip"], args["port"], handler=args["handler"])
            t = threading.Thread(target=search, args=(client,))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
        if errors:
            raise errors[0]
785
786
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads_single_connection(self, connect, args):
        '''
        target: test concurrent search with multiple threads sharing one connection
        method: search from 4 threads, all using the same connection object
        expected: every search succeeds; the top hit is an inserted entity with distance < epsilon
        '''
        threads_num = 4
        threads = []
        errors = []
        collection = gen_unique_str(collection_id)
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        _, ids = init_data(milvus, collection)

        def search(client):
            # Exceptions raised in a worker thread never reach pytest, so record them
            # and re-raise from the main thread once all workers have finished.
            try:
                res = client.search(collection, default_query)
                assert len(res) == 1
                assert res[0]._entities[0].id in ids
                assert res[0]._distances[0] < epsilon
            except Exception as e:
                errors.append(e)

        for _ in range(threads_num):
            # all threads share the single `milvus` connection
            t = threading.Thread(target=search, args=(milvus,))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
        if errors:
            raise errors[0]
817
818
    def test_search_multi_collections(self, connect, args):
        '''
        target: test search across multiple L2 collections
        method: create 10 collections, insert entities into each, then search each one with nq queries
        expected: every search returns nq results; for query j, id j is among the top hits,
                  the first distance is below epsilon and the second is above it
        '''
        num = 10
        top_k = 10
        nq = 20
        for i in range(num):
            collection = gen_unique_str(collection_id + str(i))
            connect.create_collection(collection, default_fields)
            entities, ids = init_data(connect, collection)
            assert len(ids) == nb
            query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
            res = connect.search(collection, query)
            assert len(res) == nq
            # use a separate index so the outer collection counter `i` is not shadowed
            for j in range(nq):
                assert check_id_result(res[j], ids[j])
                assert res[j]._distances[0] < epsilon
                assert res[j]._distances[1] > epsilon
839
840
841
class TestSearchDSL(object):
    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    def test_query_no_must(self, connect, collection):
        '''
        method: build query without must expr
        expected: error raised
        '''
        # entities, ids = init_data(connect, collection)
        query = update_query_expr(default_query, keep_old=False)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_no_vector_term_only(self, connect, collection):
        '''
        method: build query whose must expr holds only a term expr (no vector expr)
        expected: error raised
        '''
        # entities, ids = init_data(connect, collection)
        # gen_default_term_expr must be called: putting the function object itself into the
        # query would make the search fail for an unrelated reason and hide the real check.
        expr = {
            "must": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_vector_only(self, connect, collection):
        '''
        method: build query with the default vector expr only
        expected: nq results, each containing top_k hits
        '''
        entities, ids = init_data(connect, collection)
        res = connect.search(collection, default_query)
        assert len(res) == nq
        assert len(res[0]) == top_k

    def test_query_wrong_format(self, connect, collection):
        '''
        method: build query without must expr, using a wrong expr name ("must1")
        expected: error raised
        '''
        # entities, ids = init_data(connect, collection)
        # call gen_default_term_expr so only the wrong key name makes the query invalid
        expr = {
            "must1": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_empty(self, connect, collection):
        '''
        method: search with empty query
        expected: error raised
        '''
        query = {}
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    ******************************************************************
    #  The following cases are used to build valid query expr
    ******************************************************************
    """

    @pytest.mark.level(2)
    def test_query_term_value_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, where no entity matches the term
        expected: filter pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0
        # TODO:

    @pytest.mark.level(2)
    def test_query_term_value_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with a single term value that matches
        expected: filter pass, one hit
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 1
        # TODO:

    @pytest.mark.level(2)
    def test_query_term_values_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, where none of the term values match
        expected: filter pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(values=[i for i in range(100000, 100010)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0
        # TODO:

    def test_query_term_values_all_in(self, connect, collection):
        '''
        method: build query with vector and the default term expr
        expected: filter pass, top_k hits
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
        # TODO:

    def test_query_term_values_parts_in(self, connect, collection):
        '''
        method: build query with vector and term expr, where only part of the term values match
        expected: filter pass, top_k hits
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
        # TODO:

    @pytest.mark.level(2)
    def test_query_term_values_repeat(self, connect, collection):
        '''
        method: build query with vector and term expr, where all term values are the same
        expected: filter pass, one hit
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, nb)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 1
        # TODO:

    def test_query_term_value_empty(self, connect, collection):
        '''
        method: build query with an empty list of term values
        expected: empty result
        '''
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    """
    ******************************************************************
    #  The following cases are used to build invalid term query expr
    ******************************************************************
    """

    # TODO
    @pytest.mark.level(2)
    def test_query_term_key_error(self, connect, collection):
        '''
        method: build query with a misspelled term key
        expected: Exception raised
        '''
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(keyword="terrm", values=[i for i in range(nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_term()
    )
    def get_invalid_term(self, request):
        return request.param

    # TODO
    @pytest.mark.level(2)
    def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
        '''
        method: build query with wrong format term
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        term = get_invalid_term
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # TODO
    @pytest.mark.level(2)
    def test_query_term_field_named_term(self, connect, collection):
        '''
        method: build query on a collection that has a field named "term"
        expected: the search succeeds and returns top_k hits
        '''
        term_fields = add_field_default(default_fields, field_name="term")
        collection_term = gen_unique_str("term")
        connect.create_collection(collection_term, term_fields)
        term_entities = add_field(entities, field_name="term")
        ids = connect.insert(collection_term, term_entities)
        assert len(ids) == nb
        connect.flush([collection_term])
        count = connect.count_entities(collection_term)
        assert count == nb
        term_param = {"term": {"term": {"values": [i for i in range(nb // 2)]}}}
        expr = {"must": [gen_default_vector_expr(default_query),
                         term_param]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection_term, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
        connect.drop_collection(collection_term)

    """
    ******************************************************************
    #  The following cases are used to build valid range query expr
    ******************************************************************
    """

    # TODO
    def test_query_range_key_error(self, connect, collection):
        '''
        method: build query with a misspelled range key
        expected: Exception raised
        '''
        # named range_expr to avoid shadowing the builtin `range`
        range_expr = gen_default_range_expr(keyword="ranges")
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_range()
    )
    def get_invalid_range(self, request):
        return request.param

    # TODO
    @pytest.mark.level(2)
    def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
        '''
        method: build query with wrong format range
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        range_expr = get_invalid_range
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_valid_ranges()
    )
    def get_valid_ranges(self, request):
        return request.param

    # TODO:
    def _test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
        '''
        method: build query with valid ranges
        expected: pass, top_k hits
        '''
        entities, ids = init_data(connect, collection)
        ranges = get_valid_ranges
        range_expr = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == top_k
1127
1128
1129
class TestSearchDSLBools(object):
    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        expr = {"bool1": {}}
        # Search with the malformed expr itself. Previously `query` was undefined here,
        # so the resulting NameError satisfied pytest.raises without exercising search.
        query = expr
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        # call gen_default_term_expr; the bare function object is not a valid expr value
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        expr = {"must_not": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_should(self, connect, collection):
        '''
        method: build query with must, plus should.term
        expected: error raised
        '''
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
1194
1195
1196
"""
1197
******************************************************************
1198
#  The following cases are used to test `search` function 
1199
#  with invalid collection_name, or invalid query expr
1200
******************************************************************
1201
"""
1202
1203
1204
class TestSearchInvalid(object):
    """
    Test search collection with invalid collection names
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # skip index types that the server cannot build when running in CPU mode
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        '''
        target: test search function with an invalid collection name
        method: search the default query in the invalid collection
        expected: raise an error
        '''
        collection_name = get_collection_name
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, default_query)

    @pytest.mark.level(1)
    def test_search_with_invalid_tag(self, connect, collection):
        '''
        target: test search function with a blank partition tag
        method: search with partition_tags=" "
        expected: raise an error
        '''
        tag = " "
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        '''
        target: test search function with an invalid output field name
        method: search with fields=[<invalid name>]
        expected: raise an error
        '''
        fields = [get_invalid_field]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        '''
        target: test search function with a field name that does not exist
        method: search with fields=[<random unique name>]
        expected: raise an error
        '''
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function with a wrong top_k
        method: search with an invalid top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        # Work on a private copy: assigning into the module-level default_query would leak
        # the invalid topk into every test that runs after this one.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    Test search collection with invalid search params
    """

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    # TODO: This case can all pass, but it's too slow
    @pytest.mark.level(2)
    def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function with wrong search params (e.g. nprobe)
        method: build the index, then search with the invalid params for that index type
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        if search_params["index_type"] != index_type:
            pytest.skip("Skip case")
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.mark.level(2)
    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function with empty search params
        method: build a non-FLAT index, then search with search_params={}
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params={})
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
1333
1334
1335
def check_id_result(result, id):
    '''
    Check whether `id` appears among the leading hits of a single search result.

    Only the first 5 hits are inspected; when `result` has fewer than 5 hits,
    all of them are inspected.

    :param result: iterable of hits, each exposing an `.id` attribute, supporting len()
    :param id: the id to look for
    :return: True if `id` is found in the inspected hits, else False
    '''
    limit_in = 5
    hit_ids = [hit.id for hit in result]
    window = hit_ids[:limit_in] if len(result) >= limit_in else hit_ids
    return id in window
1342